Skip to content

Commit 786a88b

Browse files
committed
fix: _stop()
1 parent c3d5d94 commit 786a88b

File tree

1 file changed

+72
-45
lines changed

1 file changed

+72
-45
lines changed

lavender_data/client/iteration.py

Lines changed: 72 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -143,10 +143,15 @@ def __init__(
143143
self._api = _api(self._api_url, self._api_key)
144144

145145
self._bytes = 0
146-
self._complete_thread: Optional[threading.Thread] = None
146+
self._stopped = False
147147

148148
self.id = self._iteration_id
149149

150+
self._complete_thread = threading.Thread(
151+
target=self._keep_complete_indices, daemon=True
152+
)
153+
self._complete_thread.start()
154+
150155
def torch(
151156
self,
152157
pin_memory: bool = False,
@@ -206,7 +211,7 @@ def __len__(self):
206211
return self._total
207212

208213
def _keep_complete_indices(self):
209-
while True:
214+
while not self._stopped:
210215
if len(self._completed_indices) == 0:
211216
time.sleep(0.1)
212217
continue
@@ -215,12 +220,6 @@ def _keep_complete_indices(self):
215220
self.complete(index)
216221

217222
def _complete_last_indices(self):
218-
if self._complete_thread is None:
219-
self._complete_thread = threading.Thread(
220-
target=self._keep_complete_indices, daemon=True
221-
)
222-
self._complete_thread.start()
223-
224223
self._completed_indices.update(self._last_indices)
225224
self._last_indices = set()
226225

@@ -278,6 +277,10 @@ def _get_submitted_result(self, cache_key: str):
278277
except DeserializeException as e:
279278
raise ValueError(f"Failed to deserialize sample: {e}")
280279

280+
def _stop(self):
281+
self._stopped = True
282+
self._complete_thread.join()
283+
281284
def __next__(self):
282285
self._complete_last_indices()
283286

@@ -295,11 +298,18 @@ def __next__(self):
295298
return sample_or_batch
296299

297300
def __iter__(self):
298-
return self
301+
try:
302+
while True:
303+
yield next(self)
304+
finally:
305+
self._stop()
299306

300307
def __getitem__(self, index: int):
301308
return next(self)
302309

310+
def __del__(self):
311+
self._stop()
312+
303313

304314
class AsyncLavenderDataLoader:
305315
def __init__(
@@ -312,102 +322,119 @@ def __init__(
312322
if prefetch_factor < 1:
313323
raise ValueError("prefetch_factor must be greater than 0")
314324

315-
self.dl = dl
316-
self.prefetch_factor = prefetch_factor
317-
self.poll_interval = poll_interval
318-
self.in_order = in_order
319-
self.arrived: list[tuple[int, dict]] = []
320-
self.current = -1
321-
self.stopped = False
322-
self.fetch_threads: list[threading.Thread] = []
325+
self._dl = dl
326+
self._prefetch_factor = prefetch_factor
327+
self._poll_interval = poll_interval
328+
self._in_order = in_order
329+
self._arrived: list[tuple[int, dict]] = []
330+
self._current = -1
331+
self._stopped = False
332+
self._fetch_threads: list[threading.Thread] = []
333+
self._error: Optional[Exception] = None
334+
335+
def _stop(self):
336+
self._stopped = True
337+
for thread in self._fetch_threads:
338+
thread.join()
339+
self._dl._stop()
323340

324341
def _get_submitted_result(self, cache_key: str):
325342
while True:
326-
data = self.dl._get_submitted_result(cache_key)
343+
data = self._dl._get_submitted_result(cache_key)
327344
if data is not None:
328345
return data
329346
else:
330-
time.sleep(self.poll_interval)
347+
time.sleep(self._poll_interval)
331348

332349
def _fetch_one(self):
333350
try:
334-
cache_key = self.dl._submit_next_item()
351+
cache_key = self._dl._submit_next_item()
335352
except LavenderDataApiError as e:
336353
if "No more indices to pop" in str(e):
337-
self.stopped = True
354+
self._stopped = True
338355
return
339356
else:
340357
raise e
341358

342359
data = None
343360
arrived_index = None
344-
while arrived_index is None:
345-
time.sleep(self.poll_interval)
361+
while arrived_index is None and not self._stopped:
362+
time.sleep(self._poll_interval)
346363

347364
try:
348-
data = self.dl._get_submitted_result(cache_key)
365+
data = self._dl._get_submitted_result(cache_key)
349366
if data is not None:
350367
arrived_index = data["_lavender_data_current"]
351368
except StopIteration:
352-
self.stopped = True
369+
self._stopped = True
353370
return
354371
except LavenderDataSampleProcessingError as e:
355-
if self.dl._skip_on_failure:
372+
if self._dl._skip_on_failure:
356373
arrived_index = e.current
357374
else:
358375
raise e
359376

360377
return arrived_index, data
361378

362379
def _keep_fetching(self):
363-
while not self.stopped:
380+
while not self._stopped:
364381
try:
365382
arrived_index, data = self._fetch_one()
366-
self.arrived.append((arrived_index, data))
383+
self._arrived.append((arrived_index, data))
367384
except Exception as e:
368-
warnings.warn(f"Error in fetch thread: {e}")
385+
self._error = e
386+
self._stopped = True
369387

370388
def _start_fetch_threads(self):
371-
for _ in range(self.prefetch_factor):
372-
thread = threading.Thread(target=self._keep_fetching, daemon=True)
389+
for _ in range(self._prefetch_factor):
390+
thread = threading.Thread(target=self._keep_fetching)
373391
thread.start()
374-
self.fetch_threads.append(thread)
392+
self._fetch_threads.append(thread)
375393

376394
def __len__(self):
377-
return len(self.dl)
395+
return len(self._dl)
378396

379397
def __next__(self):
380-
if len(self.fetch_threads) == 0:
398+
if len(self._fetch_threads) == 0:
381399
self._start_fetch_threads()
382400

383-
self.dl._complete_last_indices()
401+
self._dl._complete_last_indices()
384402

385403
data = None
386-
next_index = self.current + 1
404+
next_index = self._current + 1
387405
while data is None:
388-
if self.stopped:
406+
if self._stopped:
407+
if self._error is not None:
408+
raise self._error
389409
raise StopIteration
390410

391411
try:
392412
arrived_index, data = (
393413
# next index is in the arrived list
394-
next((i, data) for i, data in self.arrived if i == next_index)
395-
if self.in_order
414+
next((i, data) for i, data in self._arrived if i == next_index)
415+
if self._in_order
396416
# any index is in the arrived list
397-
else next((i, data) for i, data in self.arrived)
417+
else next((i, data) for i, data in self._arrived)
398418
)
399-
self.arrived = [a for a in self.arrived if a[0] != arrived_index]
400-
self.current = next_index
401-
next_index = self.current + 1
419+
self._arrived = [a for a in self._arrived if a[0] != arrived_index]
420+
self._current = next_index
421+
next_index = self._current + 1
402422
except StopIteration:
403423
# nothing is arrived yet
404424
continue
405425

406-
self.dl._set_last_indices(data)
426+
self._dl._set_last_indices(data)
407427
return data
408428

409429
def __iter__(self):
410-
return self
430+
try:
431+
while True:
432+
yield next(self)
433+
finally:
434+
self._stop()
411435

412436
def __getitem__(self, index: int):
413437
return next(self)
438+
439+
def __del__(self):
440+
self._stop()

0 commit comments

Comments
 (0)