44import pickle
55import lbann
66from multiprocessing import Pool
7+ from multiprocessing .shared_memory import SharedMemory
78import numpy as np
89from typing import Dict , List , Optional , Union
910from numpy .typing import ArrayLike
11+ import concurrent .futures as cf
1012
1113
1214class Sample :
@@ -161,8 +163,9 @@ class DataReader:
161163 Helper class used by LBANN to control worker processes and handle sample/batch loading.
162164 """
163165
164- def __init__ (self , dataset : Dataset , num_procs : int , prefetch_factor : int ,
165- dtype : str ) -> None :
166+ def __init__ (
167+ self , dataset : Dataset , num_procs : int , prefetch_factor : int , dtype : str
168+ ) -> None :
166169 """
167170 DataReader Constructor
168171
@@ -184,13 +187,16 @@ def __init__(self, dataset: Dataset, num_procs: int, prefetch_factor: int,
184187 self .sample_dims = dataset .get_sample_dims ()
185188 self .num_io_partitions = 1
186189 self .loaded_samples = []
190+ self .thread_pool = cf .ThreadPoolExecutor (max_workers = num_procs )
187191
188192 if isinstance (self .dataset , DistConvDataset ):
189193 self .num_io_partitions = self .dataset .num_io_partitions
190194
191- self .pool = Pool (processes = num_procs ,
192- initializer = DataReader .init_worker ,
193- initargs = (self .dataset , ))
195+ self .pool = Pool (
196+ processes = num_procs ,
197+ initializer = DataReader .init_worker ,
198+ initargs = (self .dataset ,),
199+ )
194200
195201 @staticmethod
196202 def init_worker (dataset ):
def load_sample(ind) -> tuple:
    """
    Load one sample in a worker process and hand it over via shared memory.

    The sample is read from the worker-global dataset. Whichever of its
    ``sample``, ``label`` and ``response`` fields are present are flattened
    and packed back-to-back, in that order, into a newly created
    shared-memory segment. The segment is closed in this process but not
    unlinked: the consumer that attaches to it by name is responsible for
    unlinking it.

    :param ind: Index of the sample to load.
    :type ind: int
    :return: Name of the shared-memory segment and the number of elements
             stored in it.
    :rtype: Tuple[str, int]
    :raises ValueError: If the sample has none of the ``sample``, ``label``
                        or ``response`` fields.
    """
    samp = g_dataset[ind]

    # Collect the fields that exist, preserving the packing order the
    # consumer relies on: sample, then label, then response.
    fields = [
        getattr(samp, name)
        for name in ("sample", "label", "response")
        if hasattr(samp, name)
    ]
    if not fields:
        raise ValueError(
            f"Sample {ind} has no 'sample', 'label' or 'response' field to load"
        )

    # The whole buffer uses a single dtype: that of the last field present.
    # Values of the other fields are cast to it on assignment.
    # NOTE(review): the consumer reinterprets the buffer with its own
    # configured dtype - confirm datasets always produce arrays of that dtype.
    dtype = fields[-1].dtype
    shm_size = sum(field.size for field in fields)

    shm = SharedMemory(create=True, size=shm_size * dtype.itemsize)
    try:
        shm_arr = np.ndarray(shm_size, dtype=dtype, buffer=shm.buf)

        offset = 0
        for field in fields:
            new_offset = offset + field.size
            shm_arr[offset:new_offset] = field.ravel()
            offset = new_offset

        # Drop the view before the mapping is closed.
        del shm_arr
    except BaseException:
        # Nobody else knows the segment name yet, so it must be destroyed
        # here or it would leak until the machine is rebooted.
        shm_arr = None
        shm.close()
        shm.unlink()
        raise

    shm.close()
    return shm.name, shm_size
236274
def load_next_sample_async(self, ind: int):
    """
    Schedule loading of sample ``ind`` on the worker pool.

    The pending result handle is appended to ``self.loaded_samples`` so that
    it can be collected later, in submission order.
    """
    pending = self.pool.apply_async(DataReader.load_sample, (ind,))
    self.loaded_samples.append(pending)
243282
244283 def queue_samples (self , inds : List [int ]) -> None :
245284 """
def get_batch(self, batch_size: int) -> Dict[str, Union[np.ndarray, int]]:
    """
    Assemble a batch from the oldest queued samples.

    For every field declared in ``self.sample_dims`` (``sample``, ``label``,
    ``response``) a ``[batch_size, field_size]`` array is allocated. Each
    pending sample is then fetched from its shared-memory segment by a task
    on the thread pool and copied into its row, after which the segment is
    closed and unlinked.

    :param batch_size: Number of samples to take from the front of
                       ``self.loaded_samples``.
    :type batch_size: int
    :return: Batch of samples and pointers for each input field
    :rtype: Dict[str, Union[np.ndarray, int]]
    :raises IndexError: If fewer than ``batch_size`` samples are queued.
    :raises ValueError: If a loaded sample holds fewer elements than the
                        declared sample dimensions require.
    """
    batch = {}

    # (field name, elements per sample), in the order the worker packed them.
    layout = []
    for name in ("sample", "label", "response"):
        if not hasattr(self.sample_dims, name):
            continue
        field_size = int(np.prod(getattr(self.sample_dims, name)))
        if name == "sample":
            # Only the sample field is split across I/O partitions.
            field_size //= self.num_io_partitions
        layout.append((name, field_size))

        # The arrays are returned together with their pointers so that they
        # are not garbage-collected while the pointers are in use.
        batch[name] = np.empty([batch_size, field_size], dtype=self.dtype)
        batch[f"{name}_ptr"] = batch[name].ctypes.data

    expected_size = sum(field_size for _, field_size in layout)

    def copy_to_array(i, pending):
        # Blocks until the worker is done; re-raises any worker exception.
        shm_name, shm_size = pending.get()

        shm = SharedMemory(name=shm_name)
        shm_arr = None
        try:
            if shm_size < expected_size:
                # A short payload would otherwise yield truncated slices
                # that may silently broadcast into the batch rows.
                raise ValueError(
                    f"Loaded sample holds {shm_size} elements, but the "
                    f"sample dimensions require {expected_size}"
                )

            # NOTE(review): the buffer is interpreted as self.dtype - this
            # assumes the worker packed the data using that same dtype.
            shm_arr = np.ndarray(shm_size, dtype=self.dtype, buffer=shm.buf)

            offset = 0
            for name, field_size in layout:
                new_offset = offset + field_size
                batch[name][i, :] = shm_arr[offset:new_offset]
                offset = new_offset
        finally:
            # Drop the view first, then always release and destroy the
            # segment so a failed copy does not leak shared memory.
            del shm_arr
            try:
                shm.close()
            finally:
                shm.unlink()

    futures = []
    for i in range(batch_size):
        futures.append(
            self.thread_pool.submit(copy_to_array, i, self.loaded_samples.pop(0))
        )

    # Let every copy finish (and clean up its segment) before reporting
    # failures; cf.wait() alone would swallow exceptions and hand back
    # rows of uninitialised memory.
    cf.wait(futures)
    for future in futures:
        future.result()

    return batch
294352
0 commit comments