@@ -143,10 +143,15 @@ def __init__(
143143 self ._api = _api (self ._api_url , self ._api_key )
144144
145145 self ._bytes = 0
146- self ._complete_thread : Optional [ threading . Thread ] = None
146+ self ._stopped = False
147147
148148 self .id = self ._iteration_id
149149
150+ self ._complete_thread = threading .Thread (
151+ target = self ._keep_complete_indices , daemon = True
152+ )
153+ self ._complete_thread .start ()
154+
150155 def torch (
151156 self ,
152157 pin_memory : bool = False ,
@@ -206,7 +211,7 @@ def __len__(self):
206211 return self ._total
207212
208213 def _keep_complete_indices (self ):
209- while True :
214+ while not self . _stopped :
210215 if len (self ._completed_indices ) == 0 :
211216 time .sleep (0.1 )
212217 continue
@@ -215,12 +220,6 @@ def _keep_complete_indices(self):
215220 self .complete (index )
216221
217222 def _complete_last_indices (self ):
218- if self ._complete_thread is None :
219- self ._complete_thread = threading .Thread (
220- target = self ._keep_complete_indices , daemon = True
221- )
222- self ._complete_thread .start ()
223-
224223 self ._completed_indices .update (self ._last_indices )
225224 self ._last_indices = set ()
226225
@@ -278,6 +277,10 @@ def _get_submitted_result(self, cache_key: str):
278277 except DeserializeException as e :
279278 raise ValueError (f"Failed to deserialize sample: { e } " )
280279
280+ def _stop (self ):
281+ self ._stopped = True
282+ self ._complete_thread .join ()
283+
281284 def __next__ (self ):
282285 self ._complete_last_indices ()
283286
@@ -295,11 +298,18 @@ def __next__(self):
295298 return sample_or_batch
296299
297300 def __iter__ (self ):
298- return self
301+ try :
302+ while True :
303+ yield next (self )
304+ finally :
305+ self ._stop ()
299306
300307 def __getitem__ (self , index : int ):
301308 return next (self )
302309
310+ def __del__ (self ):
311+ self ._stop ()
312+
303313
304314class AsyncLavenderDataLoader :
305315 def __init__ (
@@ -312,102 +322,119 @@ def __init__(
312322 if prefetch_factor < 1 :
313323 raise ValueError ("prefetch_factor must be greater than 0" )
314324
315- self .dl = dl
316- self .prefetch_factor = prefetch_factor
317- self .poll_interval = poll_interval
318- self .in_order = in_order
319- self .arrived : list [tuple [int , dict ]] = []
320- self .current = - 1
321- self .stopped = False
322- self .fetch_threads : list [threading .Thread ] = []
325+ self ._dl = dl
326+ self ._prefetch_factor = prefetch_factor
327+ self ._poll_interval = poll_interval
328+ self ._in_order = in_order
329+ self ._arrived : list [tuple [int , dict ]] = []
330+ self ._current = - 1
331+ self ._stopped = False
332+ self ._fetch_threads : list [threading .Thread ] = []
333+ self ._error : Optional [Exception ] = None
334+
335+ def _stop (self ):
336+ self ._stopped = True
337+ for thread in self ._fetch_threads :
338+ thread .join ()
339+ self ._dl ._stop ()
323340
324341 def _get_submitted_result (self , cache_key : str ):
325342 while True :
326- data = self .dl ._get_submitted_result (cache_key )
343+ data = self ._dl ._get_submitted_result (cache_key )
327344 if data is not None :
328345 return data
329346 else :
330- time .sleep (self .poll_interval )
347+ time .sleep (self ._poll_interval )
331348
332349 def _fetch_one (self ):
333350 try :
334- cache_key = self .dl ._submit_next_item ()
351+ cache_key = self ._dl ._submit_next_item ()
335352 except LavenderDataApiError as e :
336353 if "No more indices to pop" in str (e ):
337- self .stopped = True
354+ self ._stopped = True
338355 return
339356 else :
340357 raise e
341358
342359 data = None
343360 arrived_index = None
344- while arrived_index is None :
345- time .sleep (self .poll_interval )
361+ while arrived_index is None and not self . _stopped :
362+ time .sleep (self ._poll_interval )
346363
347364 try :
348- data = self .dl ._get_submitted_result (cache_key )
365+ data = self ._dl ._get_submitted_result (cache_key )
349366 if data is not None :
350367 arrived_index = data ["_lavender_data_current" ]
351368 except StopIteration :
352- self .stopped = True
369+ self ._stopped = True
353370 return
354371 except LavenderDataSampleProcessingError as e :
355- if self .dl ._skip_on_failure :
372+ if self ._dl ._skip_on_failure :
356373 arrived_index = e .current
357374 else :
358375 raise e
359376
360377 return arrived_index , data
361378
362379 def _keep_fetching (self ):
363- while not self .stopped :
380+ while not self ._stopped :
364381 try :
365382 arrived_index , data = self ._fetch_one ()
366- self .arrived .append ((arrived_index , data ))
383+ self ._arrived .append ((arrived_index , data ))
367384 except Exception as e :
368- warnings .warn (f"Error in fetch thread: { e } " )
385+ self ._error = e
386+ self ._stopped = True
369387
370388 def _start_fetch_threads (self ):
371- for _ in range (self .prefetch_factor ):
372- thread = threading .Thread (target = self ._keep_fetching , daemon = True )
389+ for _ in range (self ._prefetch_factor ):
390+ thread = threading .Thread (target = self ._keep_fetching )
373391 thread .start ()
374- self .fetch_threads .append (thread )
392+ self ._fetch_threads .append (thread )
375393
376394 def __len__ (self ):
377- return len (self .dl )
395+ return len (self ._dl )
378396
379397 def __next__ (self ):
380- if len (self .fetch_threads ) == 0 :
398+ if len (self ._fetch_threads ) == 0 :
381399 self ._start_fetch_threads ()
382400
383- self .dl ._complete_last_indices ()
401+ self ._dl ._complete_last_indices ()
384402
385403 data = None
386- next_index = self .current + 1
404+ next_index = self ._current + 1
387405 while data is None :
388- if self .stopped :
406+ if self ._stopped :
407+ if self ._error is not None :
408+ raise self ._error
389409 raise StopIteration
390410
391411 try :
392412 arrived_index , data = (
393413 # next index is in the arrived list
394- next ((i , data ) for i , data in self .arrived if i == next_index )
395- if self .in_order
414+ next ((i , data ) for i , data in self ._arrived if i == next_index )
415+ if self ._in_order
396416 # any index is in the arrived list
397- else next ((i , data ) for i , data in self .arrived )
417+ else next ((i , data ) for i , data in self ._arrived )
398418 )
399- self .arrived = [a for a in self .arrived if a [0 ] != arrived_index ]
400- self .current = next_index
401- next_index = self .current + 1
419+ self ._arrived = [a for a in self ._arrived if a [0 ] != arrived_index ]
420+ self ._current = next_index
421+ next_index = self ._current + 1
402422 except StopIteration :
403423 # nothing is arrived yet
404424 continue
405425
406- self .dl ._set_last_indices (data )
426+ self ._dl ._set_last_indices (data )
407427 return data
408428
409429 def __iter__ (self ):
410- return self
430+ try :
431+ while True :
432+ yield next (self )
433+ finally :
434+ self ._stop ()
411435
412436 def __getitem__ (self , index : int ):
413437 return next (self )
438+
439+ def __del__ (self ):
440+ self ._stop ()
0 commit comments