99from typing import Dict , List , Optional , Union
1010from numpy .typing import ArrayLike
1111import concurrent .futures as cf
12+ from multiprocessing import resource_tracker
1213
1314
1415class Sample :
@@ -183,11 +184,14 @@ def __init__(
183184 self .dataset = dataset
184185 self .num_procs = num_procs
185186 self .prefetch_factor = prefetch_factor
186- self .dtype = dtype
187+ self .dtype = np . dtype ( dtype )
187188 self .sample_dims = dataset .get_sample_dims ()
188189 self .num_io_partitions = 1
189190 self .loaded_samples = []
190191 self .thread_pool = cf .ThreadPoolExecutor (max_workers = num_procs )
192+ self .shms = {}
193+ self .returned_shms = []
194+ self .batch = None
191195
192196 if isinstance (self .dataset , DistConvDataset ):
193197 self .num_io_partitions = self .dataset .num_io_partitions
@@ -198,6 +202,19 @@ def __init__(
198202 initargs = (self .dataset ,),
199203 )
200204
205+ self .shm_size = 0
206+ if hasattr (self .sample_dims , "sample" ):
207+ self .sample_size = (
208+ np .prod (self .sample_dims .sample ) // self .num_io_partitions
209+ )
210+ self .shm_size += self .sample_size
211+ if hasattr (self .sample_dims , "label" ):
212+ self .label_size = np .prod (self .sample_dims .sample )
213+ self .shm_size += self .label_size
214+ if hasattr (self .sample_dims , "response" ):
215+ self .response_size = self .sample_dims .response
216+ self .shm_size += self .response_size
217+
201218 @staticmethod
202219 def init_worker (dataset ):
203220 """
@@ -225,8 +242,12 @@ def terminate(self) -> None:
225242 """
226243 self .pool .terminate ()
227244
245+ for shm in self .shms .values ():
246+ shm .close ()
247+ shm .unlink ()
248+
228249 @staticmethod
229- def load_sample (ind ) -> Sample :
250+ def load_sample (ind , shm_name , shm_size , dtype ) -> Sample :
230251 """
231252 Loads the sample from the dataset at the specified index.
232253 This function must be called from a worker process.
@@ -240,19 +261,10 @@ def load_sample(ind) -> Sample:
240261 """
241262 samp = g_dataset [ind ]
242263
243- shm_size = 0
244- dtype = None
245- if hasattr (samp , "sample" ):
246- dtype = samp .sample .dtype
247- shm_size += samp .sample .size
248- if hasattr (samp , "label" ):
249- dtype = samp .label .dtype
250- shm_size += samp .label .size
251- if hasattr (samp , "response" ):
252- dtype = samp .response .dtype
253- shm_size += samp .response .size
254-
255- shm = SharedMemory (create = True , size = shm_size * dtype .itemsize )
264+ shm = SharedMemory (name = shm_name )
265+ resource_tracker .unregister (
266+ shm ._name , "shared_memory"
267+ ) # Prevent the resource tracker from interfering during process pool shutdown
256268 shm_arr = np .ndarray (shm_size , dtype = dtype , buffer = shm .buf )
257269
258270 offset = 0
@@ -270,14 +282,23 @@ def load_sample(ind) -> Sample:
270282 offset = new_offset
271283
272284 shm .close ()
273- return shm .name , shm_size
285+ return shm .name
274286
275287 def load_next_sample_async (self , ind : int ):
276288 """
277289 Submit the next sample index to be loaded to the worker pool.
278290 """
291+ if not self .returned_shms :
292+ shm = SharedMemory (create = True , size = self .shm_size * self .dtype .itemsize )
293+ shm_name = shm .name
294+ self .shms [shm_name ] = shm
295+ else :
296+ shm_name = self .returned_shms .pop ()
297+
279298 self .loaded_samples .append (
280- self .pool .apply_async (DataReader .load_sample , (ind ,))
299+ self .pool .apply_async (
300+ DataReader .load_sample , (ind , shm_name , self .shm_size , self .dtype )
301+ )
281302 )
282303
283304 def queue_samples (self , inds : List [int ]) -> None :
@@ -301,44 +322,45 @@ def get_batch(self, batch_size: int) -> Dict[str, Union[np.ndarray, int]]:
301322 :rtype: Dict[str, Union[np.ndarray, int]]
302323 """
303324
304- batch = {}
305- if hasattr (self .sample_dims , "sample" ):
306- sample_size = np .prod (self .sample_dims .sample ) // self .num_io_partitions
307- batch ["sample" ] = np .empty ([batch_size , sample_size ], dtype = self .dtype )
308- batch ["sample_ptr" ] = batch ["sample" ].ctypes .data
309- if hasattr (self .sample_dims , "label" ):
310- label_size = np .prod (self .sample_dims .sample )
311- batch ["label" ] = np .empty ([batch_size , label_size ], dtype = self .dtype )
312- batch ["label_ptr" ] = batch ["label" ].ctypes .data
313- if hasattr (self .sample_dims , "response" ):
314- response_size = self .sample_dims .response
315- batch ["response" ] = np .empty ([batch_size , response_size ], dtype = self .dtype )
316- batch ["response_ptr" ] = batch ["response" ].ctypes .data
325+ if self .batch is None :
326+ batch = {}
327+ if hasattr (self .sample_dims , "sample" ):
328+ batch ["sample" ] = np .empty (
329+ [batch_size , self .sample_size ], dtype = self .dtype
330+ )
331+ batch ["sample_ptr" ] = batch ["sample" ].ctypes .data
332+ if hasattr (self .sample_dims , "label" ):
333+ batch ["label" ] = np .empty (
334+ [batch_size , self .label_size ], dtype = self .dtype
335+ )
336+ batch ["label_ptr" ] = batch ["label" ].ctypes .data
337+ if hasattr (self .sample_dims , "response" ):
338+ batch ["response" ] = np .empty (
339+ [batch_size , self .response_size ], dtype = self .dtype
340+ )
341+ batch ["response_ptr" ] = batch ["response" ].ctypes .data
342+ self .batch = batch
317343
318344 def copy_to_array (i , sample ):
319- shm_name , shm_size = sample .get ()
320-
321- shm = SharedMemory (name = shm_name )
322- shm_arr = np .ndarray (shm_size , dtype = self .dtype , buffer = shm .buf )
345+ shm_name = sample .get ()
346+ shm = self .shms [shm_name ]
347+ shm_arr = np .ndarray (self .shm_size , dtype = self .dtype , buffer = shm .buf )
323348
324349 offset = 0
325350 if hasattr (self .sample_dims , "sample" ):
326- new_offset = offset + sample_size
327- batch ["sample" ][i , :] = shm_arr [offset :new_offset ]
351+ new_offset = offset + self . sample_size
352+ self . batch ["sample" ][i , :] = shm_arr [offset :new_offset ]
328353 offset = new_offset
329354 if hasattr (self .sample_dims , "label" ):
330- new_offset = offset + label_size
331- batch ["label" ][i , :] = shm_arr [offset :new_offset ]
355+ new_offset = offset + self . label_size
356+ self . batch ["label" ][i , :] = shm_arr [offset :new_offset ]
332357 offset = new_offset
333358 if hasattr (self .sample_dims , "response" ):
334- new_offset = offset + response_size
335- batch ["response" ][i , :] = shm_arr [offset :new_offset ]
359+ new_offset = offset + self . response_size
360+ self . batch ["response" ][i , :] = shm_arr [offset :new_offset ]
336361 offset = new_offset
337362
338- del shm_arr
339-
340- shm .close ()
341- shm .unlink ()
363+ self .returned_shms .append (shm_name )
342364
343365 futures = []
344366 for i in range (batch_size ):
@@ -347,8 +369,11 @@ def copy_to_array(i, sample):
347369 )
348370
349371 cf .wait (futures )
372+ # Check for any exceptions
373+ for f in futures :
374+ f .result ()
350375
351- return batch
376+ return self . batch
352377
353378
354379def construct_python_dataset_reader (
0 commit comments