Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions lhotse/cut/mono.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,3 +389,13 @@ def from_dict(data: dict) -> "MonoCut":
recording=recording,
supervisions=[SupervisionSegment.from_dict(s) for s in supervision_infos],
)

def __lt__(self, other):
    """Ordering used when cuts are kept in a ``heapq`` min-heap.

    A cut is "smaller" (popped earlier) when its first supervision has at
    least 3 more text tokens than ``other``'s. When the token counts are
    within 3 of each other, the cut with the longer ``duration`` is smaller.

    Cuts without supervisions, or without ``custom['tokens']['text']``, are
    treated as having zero tokens instead of raising, so that a single
    un-tokenized cut cannot break heap operations.

    Note: because of the +/-3 tolerance band this relation is not guaranteed
    to be transitive; it is a batching heuristic, not a strict total order.

    :param other: the cut to compare against.
    :return: ``True`` if ``self`` should be ordered before ``other``.
    """

    def _num_tokens(cut) -> int:
        # Missing supervision / custom field / token entry all count as 0.
        try:
            return len(cut.supervisions[0].custom["tokens"]["text"])
        except (AttributeError, IndexError, KeyError, TypeError):
            return 0

    token_diff = _num_tokens(self) - _num_tokens(other)
    if token_diff >= 3:
        return True
    if token_diff <= -3:
        return False
    # Similar text lengths: fall back to audio duration, longest first.
    return self.duration > other.duration
21 changes: 15 additions & 6 deletions lhotse/dataset/sampling/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,7 @@ class TimeConstraint(SamplingConstraint):
current: Union[int, Seconds] = 0
num_cuts: int = 0
longest_seen: Union[int, float] = 0
longest_seen_text: Union[int, float] = 0
quadratic_duration: Optional[Seconds] = None

def __post_init__(self) -> None:
Expand All @@ -454,26 +455,31 @@ def add(self, example: Cut) -> None:
selecting the right property from the input ``cut`` object.
"""
if self.max_duration is not None:
duration = self._maybe_apply_quadratic_correction(example.duration)
self.current += duration
duration = self._maybe_apply_quadratic_correction(example.duration, quadratic_expand_value=1.5)
try:
text_duration = self._maybe_apply_quadratic_correction(len(example.supervisions[0].custom['tokens']['text']) * 0.046)
except:
text_duration = 0
self.current += duration + text_duration
self.longest_seen = max(self.longest_seen, duration)
self.longest_seen_text = max(self.longest_seen_text, text_duration)
self.num_cuts += 1

def _maybe_apply_quadratic_correction(
    self, duration: Seconds, quadratic_expand_value: float = 1.0
) -> Seconds:
    """Return ``duration`` with an optional quadratic penalty term added.

    :param duration: the raw length to be corrected.
    :param quadratic_expand_value: unitless factor applied to ``duration``
        before squaring; ``1.0`` (the default) gives the unscaled penalty.
    :return: ``duration`` unchanged when ``self.quadratic_duration`` is
        ``None``; otherwise ``duration`` plus the quadratic term.
    """
    if self.quadratic_duration is None:
        return duration
    # For the quadratic complexity case, we add a term that accounts for
    # extra memory occupied by the model. With the default factor of 1.0 the
    # effective duration is doubled when it equals quadratic_duration; a
    # larger factor makes the penalty grow faster.
    scaled = duration * quadratic_expand_value
    return duration + (scaled**2) / self.quadratic_duration

def exceeded(self) -> bool:
    """Is the constraint exceeded or not.

    :return: ``True`` when ``max_cuts`` is set and ``num_cuts`` is above it,
        or when ``max_duration`` is set and the effective duration is above
        it; ``False`` otherwise (including when ``max_duration`` is ``None``).
    """
    if self.max_cuts is not None and self.num_cuts > self.max_cuts:
        return True
    if self.max_duration is None:
        return False
    # Cost the batch as if every example were as long as the longest audio
    # length plus the longest text length seen so far.
    effective_duration = self.num_cuts * (
        self.longest_seen + self.longest_seen_text
    )
    return effective_duration > self.max_duration

def close_to_exceeding(self) -> bool:
Expand All @@ -487,7 +493,7 @@ def close_to_exceeding(self) -> bool:
return True

if self.max_duration is not None:
effective_duration = (self.num_cuts + 1) * self.longest_seen
effective_duration = (self.num_cuts + 1) * (self.longest_seen + self.longest_seen_text)
return effective_duration > self.max_duration
return False

Expand All @@ -499,6 +505,7 @@ def reset(self) -> None:
self.current = 0
self.num_cuts = 0
self.longest_seen = 0
self.longest_seen_text = 0

def measure_length(self, example: Cut) -> float:
    """Return the length of ``example``, read from its ``duration`` attribute."""
    length = example.duration
    return length
Expand All @@ -512,6 +519,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self.current = state_dict.pop("current")
self.num_cuts = state_dict.pop("num_cuts")
self.longest_seen = state_dict.pop("longest_seen", 0)
self.longest_seen_text = state_dict.pop("longest_seen_text", 0)
self.quadratic_duration = state_dict.pop("quadratic_duration", None)
# backward compatibility
state_dict.pop("strict", None)
Expand All @@ -537,6 +545,7 @@ def __add__(self, other: "TimeConstraint") -> "TimeConstraint":
current=self.current + other.current,
num_cuts=self.num_cuts + other.num_cuts,
longest_seen=max(self.longest_seen, other.longest_seen),
longest_seen_text=max(self.longest_seen_text, other.longest_seen_text),
quadratic_duration=self.quadratic_duration,
)

Expand Down
22 changes: 16 additions & 6 deletions lhotse/dataset/sampling/dynamic_bucketing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Tuple,
Union,
)
import heapq

import numpy as np
import torch
Expand Down Expand Up @@ -458,8 +459,8 @@ def __init__(
)

# Init: create empty buckets (note: `num_buckets = len(duration_bins) + 1`).
self.buckets: List[Deque[Union[Cut, Tuple[Cut, ...]]]] = [
deque() for _ in range(len(duration_bins) + 1)
self.buckets: List[List[Union[Cut, Tuple[Cut, ...], Tuple[int, Cut]]]] = [
[] for _ in range(len(duration_bins) + 1)
]

def __iter__(self) -> Generator[CutSet, None, None]:
Expand All @@ -485,6 +486,10 @@ def __iter__(self) -> Generator[CutSet, None, None]:
maybe_shuffled = pick_at_random(
maybe_shuffled, rng=self.rng, out_indexes_used=indexes_used
)
else:
maybe_shuffled = pick_at_min_heap(
maybe_shuffled
)
# Sample one batch from that bucket and yield it to the caller.
batcher = DurationBatcher(
maybe_shuffled,
Expand All @@ -504,9 +509,8 @@ def __iter__(self) -> Generator[CutSet, None, None]:
for idx in indexes_used:
del sampling_bucket[idx]
else:
# No shuffling, remove first N
for _ in range(batch_size):
sampling_bucket.popleft()
# already pop by heappop
pass
# Fetch new cuts and add them to appropriate buckets.
self._collect_cuts_in_buckets(batch_size)
except StopIteration:
Expand Down Expand Up @@ -609,7 +613,7 @@ def _collect_cuts_in_buckets(self, n_cuts: int) -> None:
cuts[0] if isinstance(cuts, tuple) else cuts
)
bucket_idx = bisect_right(self.duration_bins, duration)
self.buckets[bucket_idx].append(cuts)
heapq.heappush(self.buckets[bucket_idx], cuts)
except StopIteration:
pass

Expand All @@ -629,6 +633,12 @@ def pick_at_random(
out_indexes_used.append(idx)
yield bucket[idx]

def pick_at_min_heap(
    bucket: List[Union[Cut, Tuple[Cut, ...]]],
) -> Generator[Union[Cut, Tuple[Cut, ...]], None, None]:
    """Lazily yield items from ``bucket`` in heap order, smallest first.

    ``bucket`` must be a ``list`` that already satisfies the heap invariant
    (e.g. built with ``heapq.heappush``). This generator is destructive:
    every yielded item has been removed from ``bucket`` via ``heapq.heappop``,
    so the caller must not remove consumed items a second time.

    :param bucket: heap-ordered list to drain in place.
    :return: a generator over the popped items; it stops when ``bucket`` is
        empty.
    """
    while bucket:
        yield heapq.heappop(bucket)


class BucketsDontHaveEnoughData(Exception):
pass
Expand Down
72 changes: 49 additions & 23 deletions lhotse/shar/readers/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def __iter__(self):
map_fns = self._maybe_shuffle_shards(map_fns)
map_fns = self._maybe_split_for_dataloading(map_fns)

for shard, cut_map_fn in zip(shards, map_fns):
for shard_idx, (shard, cut_map_fn) in enumerate(zip(shards, map_fns)):
# Iterate over cuts for the current shard
cuts = LazyManifestIterator(shard["cuts"])

Expand All @@ -249,33 +249,59 @@ def __iter__(self):
}

# Open every tarfile/jsonl so it's ready for streaming
field_iters = {
field: TarIterator(path)
if extension_contains(".tar", path)
else _jsonl_tar_adaptor(LazyJsonlIterator(path), field=field)
for field, path in field_paths.items()
}
field_iters = {}
for field, path in field_paths.items():
if path is None:
pass
elif extension_contains(".tar", path):
field_iters[field] = TarIterator(path)
else:
for item, pseudo_path in _jsonl_tar_adaptor(LazyJsonlIterator(path), field=field):
field_iters[field] = TarIterator(pseudo_path)
# field_iters = {
# field: TarIterator(path)
# if extension_contains(".tar", path)
# else _jsonl_tar_adaptor(LazyJsonlIterator(path), field=field)
# for field, path in field_paths.items()
# }

# *field_data contains all fields for a single cut (recording, features, array, etc.)
for cut, *field_data in zip(cuts, *field_iters.values()):
for (field, (maybe_manifest, data_path)) in zip(
try:
for cut in self.generate_shard_cuts(cuts, field_iters):
cut.shard_origin = shard["cuts"]
cut.shar_epoch = self.epoch
if cut_map_fn is not None:
cut = cut_map_fn(cut)
yield cut
except OSError as e:
print("WARNING(LazySharIterator): OSError: {}".format(e))
if shard_idx % 100 == 0:
print("INFO(LazySharIterator): EPOCH({}) Shard {}/{} done".format(self.epoch, shard_idx+1, len(shards)))

for cut in self.generate_customized_cuts():
yield cut

self.epoch += 1


def generate_customized_cuts(self):
    """Yield additional cuts; this implementation yields nothing.

    NOTE(review): presumably an extension point to be overridden -- confirm
    against callers/subclasses.
    """
    yield from ()


def generate_shard_cuts(self, cuts, field_iters):
    """Yield cuts of one shard with their per-field data attached.

    ``cuts`` and every iterator in ``field_iters`` are advanced in lockstep
    with ``zip``, so iteration stops as soon as any of them is exhausted.

    :param cuts: iterable of cut objects exposing an ``id`` attribute.
    :param field_iters: mapping of field name to an iterable of
        ``(maybe_manifest, data_path)`` pairs, aligned with ``cuts``.
    :return: a generator over the same cut objects, each mutated in place
        via ``setattr(cut, field, maybe_manifest)`` for every field whose
        manifest is not ``None``.
    :raises AssertionError: when a field's ``data_path`` (parent joined with
        stem, i.e. the path without its last suffix) does not equal the
        cut ID.
    """
    # *field_data contains all fields for a single cut (recording, features, array, etc.)
    for cut, *field_data in zip(cuts, *field_iters.values()):
        for field, (maybe_manifest, data_path) in zip(
            field_iters.keys(),
            field_data,
        ):
            if maybe_manifest is None:
                continue  # No value available for the current field for this cut.
            assert (
                str(data_path.parent / data_path.stem) == cut.id
            ), f"Mismatched IDs: cut ID is '{cut.id}' but found data with name '{data_path}' for field {field}"
            setattr(cut, field, maybe_manifest)
        yield cut

def __len__(self) -> int:
if self._len is None:
Expand Down
6 changes: 5 additions & 1 deletion lhotse/shar/readers/tar.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,11 @@ def iterate_tarfile_pairwise(
if len(result) == 2:
yield tuple(result)
result = []
result.append(parse_tarinfo(tarinfo, tar_file))
try:
result.append(parse_tarinfo(tarinfo, tar_file))
except:
print("Failed to parse tar info of tar file {}".format(tar_file))
return

if len(result) == 2:
yield tuple(result)
Expand Down