Commit 113a0ab

Use shared memory and copy out with threads

1 parent cb07681 commit 113a0ab
File tree: 1 file changed, +87 −29 lines changed
python/lbann/util/data.py (87 additions, 29 deletions)
@@ -4,9 +4,11 @@
 import pickle
 import lbann
 from multiprocessing import Pool
+from multiprocessing.shared_memory import SharedMemory
 import numpy as np
 from typing import Dict, List, Optional, Union
 from numpy.typing import ArrayLike
+import concurrent.futures as cf


 class Sample:
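
The two new imports carry the whole design: multiprocessing.shared_memory.SharedMemory provides named buffers that any process can attach to, and concurrent.futures supplies the thread pool used later for copy-out. A minimal sketch of the SharedMemory lifecycle this commit relies on (names here are illustrative, not LBANN API):

import numpy as np
from multiprocessing.shared_memory import SharedMemory

# Producer: allocate a named block and write an array into it.
data = np.arange(6, dtype=np.float32)
shm = SharedMemory(create=True, size=data.nbytes)
np.ndarray(data.shape, dtype=data.dtype, buffer=shm.buf)[:] = data
name = shm.name   # this string is all a consumer needs to attach
shm.close()       # detach this mapping; the block itself persists

# Consumer: attach by name, copy out, then free the block.
shm2 = SharedMemory(name=name)
out = np.ndarray((6,), dtype=np.float32, buffer=shm2.buf).copy()
shm2.close()
shm2.unlink()     # release the OS resource once everyone is done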
@@ -161,8 +163,9 @@ class DataReader:
     Helper class used by LBANN to control worker processes and handle sample/batch loading.
     """

-    def __init__(self, dataset: Dataset, num_procs: int, prefetch_factor: int,
-                 dtype: str) -> None:
+    def __init__(
+        self, dataset: Dataset, num_procs: int, prefetch_factor: int, dtype: str
+    ) -> None:
         """
         DataReader Constructor

@@ -184,13 +187,16 @@ def __init__(self, dataset: Dataset, num_procs: int, prefetch_factor: int,
         self.sample_dims = dataset.get_sample_dims()
         self.num_io_partitions = 1
         self.loaded_samples = []
+        self.thread_pool = cf.ThreadPoolExecutor(max_workers=num_procs)

         if isinstance(self.dataset, DistConvDataset):
             self.num_io_partitions = self.dataset.num_io_partitions

-        self.pool = Pool(processes=num_procs,
-                         initializer=DataReader.init_worker,
-                         initargs=(self.dataset, ))
+        self.pool = Pool(
+            processes=num_procs,
+            initializer=DataReader.init_worker,
+            initargs=(self.dataset,),
+        )

     @staticmethod
     def init_worker(dataset):
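
The constructor now holds two executors side by side: the process Pool does the CPU-heavy sample loading in separate interpreters, while the ThreadPoolExecutor later drains the pool's AsyncResults in the parent, where plain threads suffice because the per-sample work is a blocking get() plus a NumPy buffer copy. A toy sketch of that pairing (produce is an illustrative stand-in for load_sample):

from multiprocessing import Pool
import concurrent.futures as cf

def produce(i):
    # Stand-in for the real sample loader; runs in a worker process.
    return i * i

if __name__ == "__main__":
    with Pool(processes=4) as pool, cf.ThreadPoolExecutor(max_workers=4) as tp:
        pending = [pool.apply_async(produce, (i,)) for i in range(8)]
        # One thread per result: block on get(), then post-process.
        futures = [tp.submit(lambda r: r.get(), res) for res in pending]
        print(sorted(f.result() for f in futures))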
@@ -232,14 +238,47 @@ def load_sample(ind) -> Sample:
         :return: Sample
         :rtype: Sample
         """
-        return g_dataset[ind]
+        samp = g_dataset[ind]
+
+        shm_size = 0
+        dtype = None
+        if hasattr(samp, "sample"):
+            dtype = samp.sample.dtype
+            shm_size += samp.sample.size
+        if hasattr(samp, "label"):
+            dtype = samp.label.dtype
+            shm_size += samp.label.size
+        if hasattr(samp, "response"):
+            dtype = samp.response.dtype
+            shm_size += samp.response.size
+
+        shm = SharedMemory(create=True, size=shm_size * dtype.itemsize)
+        shm_arr = np.ndarray(shm_size, dtype=dtype, buffer=shm.buf)
+
+        offset = 0
+        if hasattr(samp, "sample"):
+            new_offset = offset + samp.sample.size
+            shm_arr[offset:new_offset] = samp.sample.ravel()
+            offset = new_offset
+        if hasattr(samp, "label"):
+            new_offset = offset + samp.label.size
+            shm_arr[offset:new_offset] = samp.label.ravel()
+            offset = new_offset
+        if hasattr(samp, "response"):
+            new_offset = offset + samp.response.size
+            shm_arr[offset:new_offset] = samp.response.ravel()
+            offset = new_offset
+
+        shm.close()
+        return shm.name, shm_size

     def load_next_sample_async(self, ind: int):
         """
         Submit the next sample index to be loaded to the worker pool.
         """
         self.loaded_samples.append(
-            self.pool.apply_async(DataReader.load_sample, (ind, )))
+            self.pool.apply_async(DataReader.load_sample, (ind,))
+        )

     def queue_samples(self, inds: List[int]) -> None:
         """
@@ -261,34 +300,53 @@ def get_batch(self, batch_size: int) -> Dict[str, Union[np.ndarray, int]]:
         :return: Batch of samples and pointers for each input field
         :rtype: Dict[str, Union[np.ndarray, int]]
         """
-        samples = []
-        for _ in range(batch_size):
-            samples.append(self.loaded_samples.pop(0).get())

         batch = {}
-
-        # Note: we return the arrays with the pointers so that they aren't
-        # deallocated by the garbage collector.
-        batch["sample"] = np.ascontiguousarray([s.sample for s in samples],
-                                               dtype=self.dtype)
-        batch["sample_ptr"] = batch["sample"].ctypes.data
-        assert (batch["sample"].size == np.prod(self.sample_dims.sample) *
-                batch_size / self.num_io_partitions)
-
+        if hasattr(self.sample_dims, "sample"):
+            sample_size = np.prod(self.sample_dims.sample) // self.num_io_partitions
+            batch["sample"] = np.empty([batch_size, sample_size], dtype=self.dtype)
+            batch["sample_ptr"] = batch["sample"].ctypes.data
         if hasattr(self.sample_dims, "label"):
-            batch["label"] = np.ascontiguousarray([s.label for s in samples],
-                                                  dtype=self.dtype)
+            label_size = np.prod(self.sample_dims.label)
+            batch["label"] = np.empty([batch_size, label_size], dtype=self.dtype)
             batch["label_ptr"] = batch["label"].ctypes.data
-            assert batch["label"].size == np.prod(
-                self.sample_dims.label) * batch_size
-
         if hasattr(self.sample_dims, "response"):
-            batch["response"] = np.ascontiguousarray(
-                [s.response for s in samples], dtype=self.dtype)
+            response_size = self.sample_dims.response
+            batch["response"] = np.empty([batch_size, response_size], dtype=self.dtype)
             batch["response_ptr"] = batch["response"].ctypes.data
-            assert (
-                batch["response"].size == np.prod(self.sample_dims.response) *
-                batch_size)
+
+        def copy_to_array(i, sample):
+            shm_name, shm_size = sample.get()
+
+            shm = SharedMemory(name=shm_name)
+            shm_arr = np.ndarray(shm_size, dtype=self.dtype, buffer=shm.buf)
+
+            offset = 0
+            if hasattr(self.sample_dims, "sample"):
+                new_offset = offset + sample_size
+                batch["sample"][i, :] = shm_arr[offset:new_offset]
+                offset = new_offset
+            if hasattr(self.sample_dims, "label"):
+                new_offset = offset + label_size
+                batch["label"][i, :] = shm_arr[offset:new_offset]
+                offset = new_offset
+            if hasattr(self.sample_dims, "response"):
+                new_offset = offset + response_size
+                batch["response"][i, :] = shm_arr[offset:new_offset]
+                offset = new_offset
+
+            del shm_arr
+
+            shm.close()
+            shm.unlink()
+
+        futures = []
+        for i in range(batch_size):
+            futures.append(
+                self.thread_pool.submit(copy_to_array, i, self.loaded_samples.pop(0))
+            )
+
+        cf.wait(futures)

         return batch
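Taken together, get_batch now preallocates each output array once, then fans the per-sample copies out across the thread pool; each copy_to_array call attaches to one worker's block, scatters it into row i, and unlinks the block. A hedged usage sketch (MyDataset and the constructor arguments are illustrative, not from this file):

reader = DataReader(MyDataset(), num_procs=4, prefetch_factor=2, dtype="float32")

batch_size = 32
reader.queue_samples(list(range(batch_size)))  # workers pack samples into shared memory
batch = reader.get_batch(batch_size)           # threads copy them out in parallel

# batch["sample"] is a contiguous [batch_size, sample_size] ndarray and
# batch["sample_ptr"] is the raw address handed to LBANN's ingest path.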
