Skip to content

Commit dbfa2d1

Browse files
committed
Apply pinned formatting
1 parent 34f7140 commit dbfa2d1

9 files changed

Lines changed: 156 additions & 63 deletions

File tree

lhotse/cut/set.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4177,7 +4177,9 @@ def _mix_one(self, cut: Cut, rng: random.Random) -> Cut:
41774177
# Actual mixing
41784178
to_mix = self._next_mix_in_cut(rng)
41794179
to_mix = self._maybe_truncate_cut(to_mix, target_mixed_duration, rng)
4180-
mixed = cut.mix(other=to_mix, snr=cut_snr, preserve_id=self.preserve_id, tag=self.tag)
4180+
mixed = cut.mix(
4181+
other=to_mix, snr=cut_snr, preserve_id=self.preserve_id, tag=self.tag
4182+
)
41814183
# Did the user specify a duration?
41824184
# If yes, we will ensure that shorter cuts have more noise mixed in
41834185
# to "pad" them with at the end.

lhotse/lazy.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -611,7 +611,11 @@ def state_dict(self) -> dict:
611611
# ``shuffle`` and ``seed`` are surfaced in the dict for inspection
612612
# and forward-compat — they are not consumed on load (the iterator
613613
# is reconstructed with the same constructor args).
614-
return {**self._iter_state.state_dict(), "shuffle": self.shuffle, "seed": self.seed}
614+
return {
615+
**self._iter_state.state_dict(),
616+
"shuffle": self.shuffle,
617+
"seed": self.seed,
618+
}
615619

616620
def load_state_dict(self, sd: dict) -> None:
617621
if self.shuffle and "range" not in sd:

lhotse/shar/readers/indexed.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,8 @@ def __init__(
128128
self._lazy = lazy
129129
self.epoch = 0
130130
self._iter_state = PartitionedIndexedIterator(
131-
shuffle=self.shuffle, seed=resolve_seed(self.seed) if isinstance(self.seed, int) else 0
131+
shuffle=self.shuffle,
132+
seed=resolve_seed(self.seed) if isinstance(self.seed, int) else 0,
132133
)
133134

134135
# Build indexed readers for cuts JSONL shards and compute lengths.

test/ais/test_batch_loader_moss_in.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
from aistore.sdk.batch.types import MossIn, MossReq
2222

23-
2423
# ---------------------------------------------------------------------------
2524
# Field presence: the fast path passes specific kwargs to MossIn — assert
2625
# they're all real fields on the current SDK so a rename surfaces here.
@@ -44,9 +43,7 @@ def test_mossin_obj_name_is_required():
4443
"""``obj_name`` is the only required field today; if more get added,
4544
the fast path's kwargs-sparse build silently produces an invalid request."""
4645
required = [
47-
name
48-
for name, info in MossIn.model_fields.items()
49-
if info.is_required()
46+
name for name, info in MossIn.model_fields.items() if info.is_required()
5047
]
5148
assert set(required) <= {"obj_name"}, (
5249
f"MossIn introduced new required fields {set(required) - {'obj_name'}} "
@@ -65,18 +62,39 @@ def test_mossin_obj_name_is_required():
6562

6663
_FAST_KWARG_COMBOS = [
6764
# (description, kwargs)
68-
("url-only (e.g. recording.source.type='url')",
69-
{"obj_name": "audio.wav", "bck": "bkt", "provider": "ais"}),
70-
("url with archpath (tar member)",
71-
{"obj_name": "shard.tar", "bck": "bkt", "provider": "ais", "archpath": "rec1.wav"}),
72-
("shar_ptr (byte-range)",
73-
{"obj_name": "shard.tar", "bck": "bkt", "provider": "ais", "start": 4096, "length": 65536}),
74-
("aws provider variant",
75-
{"obj_name": "shard.tar", "bck": "bkt", "provider": "aws"}),
65+
(
66+
"url-only (e.g. recording.source.type='url')",
67+
{"obj_name": "audio.wav", "bck": "bkt", "provider": "ais"},
68+
),
69+
(
70+
"url with archpath (tar member)",
71+
{
72+
"obj_name": "shard.tar",
73+
"bck": "bkt",
74+
"provider": "ais",
75+
"archpath": "rec1.wav",
76+
},
77+
),
78+
(
79+
"shar_ptr (byte-range)",
80+
{
81+
"obj_name": "shard.tar",
82+
"bck": "bkt",
83+
"provider": "ais",
84+
"start": 4096,
85+
"length": 65536,
86+
},
87+
),
88+
(
89+
"aws provider variant",
90+
{"obj_name": "shard.tar", "bck": "bkt", "provider": "aws"},
91+
),
7692
]
7793

7894

79-
@pytest.mark.parametrize("desc,kwargs", _FAST_KWARG_COMBOS, ids=[c[0] for c in _FAST_KWARG_COMBOS])
95+
@pytest.mark.parametrize(
96+
"desc,kwargs", _FAST_KWARG_COMBOS, ids=[c[0] for c in _FAST_KWARG_COMBOS]
97+
)
8098
def test_model_construct_matches_validating_constructor(desc: str, kwargs: dict):
8199
"""``MossIn.model_construct(**kwargs)`` must serialize identically to
82100
``MossIn(**kwargs)`` — same field values, same defaults for omitted
@@ -119,6 +137,7 @@ def test_mossreq_has_moss_in_list_field():
119137
annotation = MossReq.model_fields["moss_in"].annotation
120138
# typing.List[MossIn] subscripts; check origin is list-like.
121139
import typing as _t
140+
122141
assert _t.get_origin(annotation) is list, (
123142
f"MossReq.moss_in type changed from List[MossIn] to {annotation} "
124143
f"on aistore {aistore.__version__}. Fast path's append() may break."
@@ -141,6 +160,7 @@ def test_batch_requests_list_is_public_accessor():
141160
# Most robust check is source inspection (kept lenient — only flag a hard
142161
# rename of 'moss_in', anything else gets caught by the round-trip tests).
143162
import inspect
163+
144164
src = inspect.getsource(descriptor.fget)
145165
assert "moss_in" in src, (
146166
f"Batch.requests_list source no longer references moss_in on "

test/dataset/sampling/test_sampler_restoring.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,9 @@ def test_load_state_dict_rejects_cross_rank_state():
356356
src = SimpleCutSampler(CUTS, max_duration=10.0, world_size=4, rank=0)
357357
dst = SimpleCutSampler(CUTS, max_duration=10.0, world_size=4, rank=2)
358358
state = src.state_dict()
359-
with pytest.raises(RuntimeError, match="state was saved on rank=0 but is being loaded on rank=2"):
359+
with pytest.raises(
360+
RuntimeError, match="state was saved on rank=0 but is being loaded on rank=2"
361+
):
360362
dst.load_state_dict(state)
361363

362364

test/dataset/test_checkpoint_restore.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -564,7 +564,9 @@ def test_mixed_raises(self, cuts):
564564
from lhotse.lazy import LazyIteratorMultiplexer
565565

566566
streaming_leaf = self._make_streaming_leaf()(list(cuts)[:50])
567-
indexed_leaf = _IndexedCutsWithoutGraphOrigin(list(cuts)[50:]) # has_constant_time_access=True per definition above
567+
indexed_leaf = _IndexedCutsWithoutGraphOrigin(
568+
list(cuts)[50:]
569+
) # has_constant_time_access=True per definition above
568570
mux = LazyIteratorMultiplexer(streaming_leaf, indexed_leaf, seed=0)
569571
cs = CutSet(mux)
570572

test/shar/test_indexed_partition.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
from lhotse.shar.writers import SharWriter
3333
from lhotse.testing.dummies import DummyManifest
3434

35-
3635
_PARTITION_ENV_KEYS = ("RANK", "WORLD_SIZE", LHOTSE_USE_WORKER_PARTITION)
3736

3837

@@ -91,9 +90,9 @@ def test_indexed_shar_partition_disjoint_and_complete(indexed_shar_dir, world_si
9190
with _env_partition(rank=rank, world_size=world_size):
9291
ids = {c.id for c in LazyIndexedSharIterator(in_dir=indexed_shar_dir)}
9392
for prev in per_rank:
94-
assert prev.isdisjoint(ids), (
95-
f"rank {rank} slice overlaps prior rank: {sorted(prev & ids)}"
96-
)
93+
assert prev.isdisjoint(
94+
ids
95+
), f"rank {rank} slice overlaps prior rank: {sorted(prev & ids)}"
9796
per_rank.append(ids)
9897

9998
union: set = set()
@@ -172,11 +171,16 @@ def test_indexed_shar_partition_works_with_shuffle(indexed_shar_dir, shuffle):
172171
union: set = set()
173172
for rank in range(world_size):
174173
with _env_partition(rank=rank, world_size=world_size):
175-
ids = {c.id for c in LazyIndexedSharIterator(
176-
in_dir=indexed_shar_dir, shuffle=shuffle, seed=42,
177-
)}
178-
assert union.isdisjoint(ids), (
179-
f"rank {rank} overlaps: {sorted(union & ids)} (shuffle={shuffle})"
180-
)
174+
ids = {
175+
c.id
176+
for c in LazyIndexedSharIterator(
177+
in_dir=indexed_shar_dir,
178+
shuffle=shuffle,
179+
seed=42,
180+
)
181+
}
182+
assert union.isdisjoint(
183+
ids
184+
), f"rank {rank} overlaps: {sorted(union & ids)} (shuffle={shuffle})"
181185
union |= ids
182186
assert union == expected

test/test_indexing.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -216,15 +216,15 @@ def worker():
216216

217217
assert not errors, f"workers raised: {errors[:3]}"
218218
assert lengths, "no worker reported a length — barrier wait failed?"
219-
assert all(L == n_records for L in lengths), (
220-
f"workers saw inconsistent lengths {Counter(lengths)}; expected all == {n_records}"
221-
)
219+
assert all(
220+
L == n_records for L in lengths
221+
), f"workers saw inconsistent lengths {Counter(lengths)}; expected all == {n_records}"
222222

223223
# Final .idx must be non-empty + size-aligned to uint64.
224224
idx_size = index_file_path(p).stat().st_size
225-
assert idx_size > 0 and idx_size % 8 == 0, (
226-
f"final .idx is {idx_size} bytes — not a clean uint64 array"
227-
)
225+
assert (
226+
idx_size > 0 and idx_size % 8 == 0
227+
), f"final .idx is {idx_size} bytes — not a clean uint64 array"
228228

229229

230230
# ---------------------------------------------------------------------------
@@ -428,7 +428,9 @@ def test_lazy_shuffled_range_partition_shard_lengths(n, num_shards):
428428
total = 0
429429
for shard_id in range(num_shards):
430430
shard = LazyShuffledRange(n, seed=42, shard_id=shard_id, num_shards=num_shards)
431-
expected_len = max(0, (n - shard_id + num_shards - 1) // num_shards) if n > shard_id else 0
431+
expected_len = (
432+
max(0, (n - shard_id + num_shards - 1) // num_shards) if n > shard_id else 0
433+
)
432434
assert len(shard) == expected_len
433435
total += len(shard)
434436
assert total == n
@@ -1199,7 +1201,9 @@ def fake_open_best(path, mode="r"):
11991201
return original_open_best(target, mode)
12001202
return original_open_best(path, mode)
12011203

1202-
monkeypatch.setattr(indexing_mod, "_open_for_indexed_read", fake_open_for_indexed_read)
1204+
monkeypatch.setattr(
1205+
indexing_mod, "_open_for_indexed_read", fake_open_for_indexed_read
1206+
)
12031207
monkeypatch.setattr(indexing_mod, "open_best", fake_open_best)
12041208

12051209

@@ -1298,9 +1302,7 @@ def test_indexed_tar_reader_remote_data_with_remote_index_path(
12981302
assert str(data_path) == samples[1][0]
12991303

13001304

1301-
def test_read_index_remote_path_is_cached_locally(
1302-
tmp_path, jsonl_file, monkeypatch
1303-
):
1305+
def test_read_index_remote_path_is_cached_locally(tmp_path, jsonl_file, monkeypatch):
13041306
"""Remote index files are downloaded once, cached on disk, and reused
13051307
from the cache on subsequent calls (no second remote fetch)."""
13061308
import lhotse.indexing as indexing_mod

0 commit comments

Comments
 (0)