Skip to content

Commit 62d7f29

Browse files
fix: handle truncated batch stream with sequential GET fallback
- catch StopIteration during batch result iteration and fall back to individual GET requests instead of crashing the DataLoader worker - use batch_stream_failed flag to skip dead iterator for remaining objects - reuse existing _get_object_from_moss_in() retry path for recovery - update test to verify fallback behavior instead of expecting crash Signed-off-by: Abhishek Gaikwad <gaikwadabhishek1997@gmail.com>
1 parent 8feed52 commit 62d7f29

2 files changed

Lines changed: 60 additions & 32 deletions

File tree

lhotse/ais/batch_loader.py

Lines changed: 39 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def __call__(self, cuts: CutSet) -> CutSet:
133133
logger.warning(
134134
f"AIStore batch.get() failed: {e}. Falling back to sequential GET requests."
135135
)
136+
136137
# Fallback: make sequential GET requests for each object in the batch
137138
# Use a generator to maintain consistency with batch.get() which returns an iterator
138139
def sequential_get():
@@ -153,33 +154,49 @@ def sequential_get():
153154

154155
# Apply the received data back into each manifest that had a URL
155156
request_idx = 0
157+
batch_stream_failed = False
156158
for manifest, has_url in manifest_list:
157159
if has_url:
158160
info = None
159161
content = None
160162

161-
try:
162-
info, content = next(batch_result)
163-
except StopIteration:
164-
raise AISBatchLoaderError(
165-
"Batch result iterator exhausted prematurely. "
166-
f"Expected more objects for manifests with URLs."
167-
)
168-
except TimeoutError as e:
169-
# Timeout occurred - recover the request info from saved_requests_list
170-
logger.warning(
171-
f"Timeout while fetching batch result at index {request_idx}: {e}. "
172-
f"Falling back to direct AIStore API call."
173-
)
174-
175-
if request_idx < len(saved_requests_list):
176-
info = saved_requests_list[request_idx]
177-
content = b"" # Mark as empty to trigger retry
178-
else:
179-
raise AISBatchLoaderError(
180-
f"Timeout at request index {request_idx}, but cannot recover: "
181-
f"index out of range for saved_requests_list (len={len(saved_requests_list)})"
182-
) from e
163+
if batch_stream_failed:
164+
# Batch stream already broke — go straight to individual GET
165+
info = saved_requests_list[request_idx]
166+
content = b"" # trigger retry below
167+
else:
168+
try:
169+
info, content = next(batch_result)
170+
except StopIteration:
171+
# Batch stream was truncated (e.g., connection reset mid-tar).
172+
# Fall back to individual GET for this and all remaining objects.
173+
batch_stream_failed = True
174+
logger.warning(
175+
f"Batch stream truncated at index {request_idx}/{len(saved_requests_list)}. "
176+
f"Falling back to direct AIStore API calls for remaining objects."
177+
)
178+
if request_idx < len(saved_requests_list):
179+
info = saved_requests_list[request_idx]
180+
content = b"" # trigger retry below
181+
else:
182+
raise AISBatchLoaderError(
183+
f"Batch stream truncated at index {request_idx}, but cannot recover: "
184+
f"index out of range for saved_requests_list (len={len(saved_requests_list)})"
185+
)
186+
except TimeoutError as e:
187+
# Timeout occurred - recover the request info from saved_requests_list
188+
logger.warning(
189+
f"Timeout while fetching batch result at index {request_idx}: {e}. "
190+
f"Falling back to direct AIStore API call."
191+
)
192+
if request_idx < len(saved_requests_list):
193+
info = saved_requests_list[request_idx]
194+
content = b"" # Mark as empty to trigger retry
195+
else:
196+
raise AISBatchLoaderError(
197+
f"Timeout at request index {request_idx}, but cannot recover: "
198+
f"index out of range for saved_requests_list (len={len(saved_requests_list)})"
199+
) from e
183200

184201
# Retry with direct API call if content is empty (from timeout or actual empty response)
185202
if content == b"":

test/cut/test_ais_batch_loader.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
These tests use mocking to simulate AIStore client behavior,
55
allowing them to run in CI environments without AIStore infrastructure.
66
"""
7+
78
from unittest.mock import MagicMock, patch
89

910
import numpy as np
@@ -968,35 +969,45 @@ def test_fallback_failure_raises_error(
968969
loader(cuts)
969970

970971
@patch("lhotse.ais.batch_loader.get_aistore_client")
971-
def test_iterator_exhausted_raises_error(
972+
def test_iterator_exhausted_falls_back_to_sequential(
972973
self, mock_get_client, cut_with_url_recording
973974
):
974-
"""Test that iterator exhaustion raises AISBatchLoaderError."""
975+
"""Test that iterator exhaustion falls back to individual GET requests."""
975976
client = MagicMock()
976977
batch = MagicMock()
977978

978979
# Track add() calls
979980
add_count = []
980981
batch.add.side_effect = lambda *args, **kwargs: add_count.append(1)
981982

982-
# Mock batch.get() to return fewer items than expected
983+
# Mock batch.get() to return fewer items than expected (empty iterator)
983984
def mock_batch_get():
984-
# Return nothing even though we expect 1 item
985985
return iter([])
986986

987987
batch.get.side_effect = lambda: mock_batch_get()
988+
batch.requests_list = [
989+
MagicMock(
990+
obj_name="test.wav", bck="test-bucket", provider="ais", archpath=""
991+
)
992+
]
988993
client.batch.return_value = batch
989-
client.bucket.return_value = MagicMock()
994+
mock_bucket = MagicMock()
995+
mock_obj = MagicMock()
996+
mock_reader = MagicMock()
997+
mock_reader.read_all.return_value = b"\x00" * 16000
998+
mock_obj.get_reader.return_value = mock_reader
999+
mock_bucket.object.return_value = mock_obj
1000+
client.bucket.return_value = mock_bucket
9901001
mock_get_client.return_value = (client, None)
9911002

9921003
loader = AISBatchLoader()
9931004
cuts = CutSet.from_cuts([cut_with_url_recording])
9941005

995-
# Should raise AISBatchLoaderError when iterator is exhausted
996-
with pytest.raises(
997-
AISBatchLoaderError, match="Batch result iterator exhausted prematurely"
998-
):
999-
loader(cuts)
1006+
# Should NOT raise — falls back to individual GET
1007+
result = loader(cuts)
1008+
assert result is not None
1009+
# Verify the fallback GET was called
1010+
mock_obj.get_reader.assert_called()
10001011

10011012
@patch("lhotse.ais.batch_loader.get_aistore_client")
10021013
def test_multiple_cuts_with_fallback(self, mock_get_client, cut_with_url_recording):

0 commit comments

Comments
 (0)