Commit 04c297b

Further IO improvements (#2477)
1 parent 1e5114c commit 04c297b

2 files changed: +89 -59 lines
Lines changed: 19 additions & 14 deletions
@@ -1,31 +1,36 @@
 import numpy as np
 from glob import glob
-from lbann.util.data import Sample, SampleDims, Dataset, DistConvDataset
+from lbann.util.data import Sample, SampleDims, DistConvDataset
 import h5py as h5
 import os


 class CosmoFlowDataset(DistConvDataset):
     def __init__(self, data_dir, input_width, num_secrets):
         self.data_dir = data_dir
         self.input_width = input_width
         self.num_secrets = num_secrets
-        self.samples = glob(os.path.join(data_dir, '*.hdf5'))
+        self.samples = glob(os.path.join(data_dir, "*.hdf5"))
         self.samples.sort()

     def __len__(self):
         return len(self.samples)

     def __getitem__(self, index) -> Sample:
-        data = h5.File(self.samples[index], 'r')
+        data = h5.File(self.samples[index], "r")
         slice_width = self.input_width // self.num_io_partitions
         slice_ind = self.rank % self.num_io_partitions
-        full = data['full'][:,
-                            slice_ind*slice_width:(slice_ind+1)*slice_width,
-                            :self.input_width,
-                            :self.input_width].astype(np.float32)
-        par = data['unitPar'][:].astype(np.float32)
-        return Sample(sample=np.ascontiguousarray(full), response=par)
+        full = data["full"][
+            :,
+            slice_ind * slice_width : (slice_ind + 1) * slice_width,
+            : self.input_width,
+            : self.input_width,
+        ]
+        par = data["unitPar"][:].astype(np.float32)
+        return Sample(sample=full, response=par)

     def get_sample_dims(self):
-        return SampleDims(sample=[4, self.input_width, self.input_width, self.input_width], response=self.num_secrets)
+        return SampleDims(
+            sample=[4, self.input_width, self.input_width, self.input_width],
+            response=self.num_secrets,
+        )
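The rewrite above also drops the astype and ascontiguousarray copies; the h5py selection already returns a fresh contiguous array, and the dtype conversion presumably happens later, when the sample is copied into the reader's shared-memory buffer. For context, a minimal sketch of the DistConv slicing with hypothetical values (a 4-channel 128^3 volume, 4 I/O partitions, viewed from rank 2; in LBANN, num_io_partitions and rank are set on the dataset by the runtime):

    import numpy as np

    input_width, num_io_partitions, rank = 128, 4, 2

    slice_width = input_width // num_io_partitions   # 32 planes per partition
    slice_ind = rank % num_io_partitions             # rank 2 reads planes 64:96
    volume = np.zeros((4, input_width, input_width, input_width), dtype=np.float32)
    shard = volume[:, slice_ind * slice_width:(slice_ind + 1) * slice_width,
                   :input_width, :input_width]
    assert shard.shape == (4, slice_width, input_width, input_width)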

python/lbann/util/data.py

Lines changed: 70 additions & 45 deletions
@@ -9,6 +9,7 @@
 from typing import Dict, List, Optional, Union
 from numpy.typing import ArrayLike
 import concurrent.futures as cf
+from multiprocessing import resource_tracker


 class Sample:
@@ -183,11 +184,14 @@ def __init__(
         self.dataset = dataset
         self.num_procs = num_procs
         self.prefetch_factor = prefetch_factor
-        self.dtype = dtype
+        self.dtype = np.dtype(dtype)
         self.sample_dims = dataset.get_sample_dims()
         self.num_io_partitions = 1
         self.loaded_samples = []
         self.thread_pool = cf.ThreadPoolExecutor(max_workers=num_procs)
+        self.shms = {}
+        self.returned_shms = []
+        self.batch = None

         if isinstance(self.dataset, DistConvDataset):
             self.num_io_partitions = self.dataset.num_io_partitions
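Normalizing with np.dtype() means the reader no longer cares whether callers pass a string, a scalar type, or a dtype object; all three yield a dtype with a valid itemsize, which the shared-memory sizing below depends on. A small illustration (the float32 values are hypothetical, not from the commit):

    import numpy as np

    # All three specs normalize to the same dtype object with itemsize 4.
    for spec in ("float32", np.float32, np.dtype(np.float32)):
        assert np.dtype(spec).itemsize == 4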
@@ -198,6 +202,19 @@ def __init__(
             initargs=(self.dataset,),
         )

+        self.shm_size = 0
+        if hasattr(self.sample_dims, "sample"):
+            self.sample_size = (
+                np.prod(self.sample_dims.sample) // self.num_io_partitions
+            )
+            self.shm_size += self.sample_size
+        if hasattr(self.sample_dims, "label"):
+            self.label_size = np.prod(self.sample_dims.sample)
+            self.shm_size += self.label_size
+        if hasattr(self.sample_dims, "response"):
+            self.response_size = self.sample_dims.response
+            self.shm_size += self.response_size
+
     @staticmethod
     def init_worker(dataset):
         """
@@ -225,8 +242,12 @@ def terminate(self) -> None:
         """
         self.pool.terminate()

+        for shm in self.shms.values():
+            shm.close()
+            shm.unlink()
+
     @staticmethod
-    def load_sample(ind) -> Sample:
+    def load_sample(ind, shm_name, shm_size, dtype) -> Sample:
         """
         Loads the sample from the dataset at the specified index.
         This function must be called from a worker process.
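terminate() is now the one place where segments are destroyed. close() only detaches the calling process's mapping, while unlink() removes the segment itself, so only the owning process should unlink, and only once. A standalone illustration (not from the commit):

    from multiprocessing.shared_memory import SharedMemory

    shm = SharedMemory(create=True, size=1024)
    shm.close()    # detach this process's mapping
    shm.unlink()   # destroy the underlying segment for good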
@@ -240,19 +261,10 @@ def load_sample(ind) -> Sample:
         """
         samp = g_dataset[ind]

-        shm_size = 0
-        dtype = None
-        if hasattr(samp, "sample"):
-            dtype = samp.sample.dtype
-            shm_size += samp.sample.size
-        if hasattr(samp, "label"):
-            dtype = samp.label.dtype
-            shm_size += samp.label.size
-        if hasattr(samp, "response"):
-            dtype = samp.response.dtype
-            shm_size += samp.response.size
-
-        shm = SharedMemory(create=True, size=shm_size * dtype.itemsize)
+        shm = SharedMemory(name=shm_name)
+        resource_tracker.unregister(
+            shm._name, "shared_memory"
+        )  # Prevent the resource tracker from interfering during process pool shutdown
         shm_arr = np.ndarray(shm_size, dtype=dtype, buffer=shm.buf)

         offset = 0
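Workers now attach to a segment the parent created rather than creating their own. Attaching via SharedMemory(name=...) also registers the segment with the worker's resource tracker, which would try to clean it up during pool shutdown; the unregister call undoes that so the parent keeps sole ownership. A sketch of the pattern (attach_existing is a hypothetical helper, not in the commit; resource_tracker.unregister is the same undocumented call the diff uses):

    from multiprocessing import resource_tracker
    from multiprocessing.shared_memory import SharedMemory

    def attach_existing(name: str) -> SharedMemory:
        shm = SharedMemory(name=name)
        # Undo the automatic registration so this process never unlinks
        # a segment it does not own.
        resource_tracker.unregister(shm._name, "shared_memory")
        return shm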
@@ -270,14 +282,23 @@ def load_sample(ind) -> Sample:
             offset = new_offset

         shm.close()
-        return shm.name, shm_size
+        return shm.name

     def load_next_sample_async(self, ind: int):
         """
         Submit the next sample index to be loaded to the worker pool.
         """
+        if not self.returned_shms:
+            shm = SharedMemory(create=True, size=self.shm_size * self.dtype.itemsize)
+            shm_name = shm.name
+            self.shms[shm_name] = shm
+        else:
+            shm_name = self.returned_shms.pop()
+
         self.loaded_samples.append(
-            self.pool.apply_async(DataReader.load_sample, (ind,))
+            self.pool.apply_async(
+                DataReader.load_sample, (ind, shm_name, self.shm_size, self.dtype)
+            )
         )

     def queue_samples(self, inds: List[int]) -> None:
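Segment allocation is thereby amortized: a new segment is created only when no finished one has been returned, so a steady state cycles through a fixed working set of segments instead of creating and unlinking one per sample. A minimal standalone sketch of the same recycling scheme (acquire and release are hypothetical names):

    from multiprocessing.shared_memory import SharedMemory

    shms, returned_shms = {}, []

    def acquire(nbytes: int) -> str:
        if returned_shms:
            return returned_shms.pop()          # reuse a returned segment
        shm = SharedMemory(create=True, size=nbytes)
        shms[shm.name] = shm                    # remember it for central cleanup
        return shm.name

    def release(name: str) -> None:
        returned_shms.append(name)              # make it available for reuse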
@@ -301,44 +322,45 @@ def get_batch(self, batch_size: int) -> Dict[str, Union[np.ndarray, int]]:
         :rtype: Dict[str, Union[np.ndarray, int]]
         """

-        batch = {}
-        if hasattr(self.sample_dims, "sample"):
-            sample_size = np.prod(self.sample_dims.sample) // self.num_io_partitions
-            batch["sample"] = np.empty([batch_size, sample_size], dtype=self.dtype)
-            batch["sample_ptr"] = batch["sample"].ctypes.data
-        if hasattr(self.sample_dims, "label"):
-            label_size = np.prod(self.sample_dims.sample)
-            batch["label"] = np.empty([batch_size, label_size], dtype=self.dtype)
-            batch["label_ptr"] = batch["label"].ctypes.data
-        if hasattr(self.sample_dims, "response"):
-            response_size = self.sample_dims.response
-            batch["response"] = np.empty([batch_size, response_size], dtype=self.dtype)
-            batch["response_ptr"] = batch["response"].ctypes.data
+        if self.batch is None:
+            batch = {}
+            if hasattr(self.sample_dims, "sample"):
+                batch["sample"] = np.empty(
+                    [batch_size, self.sample_size], dtype=self.dtype
+                )
+                batch["sample_ptr"] = batch["sample"].ctypes.data
+            if hasattr(self.sample_dims, "label"):
+                batch["label"] = np.empty(
+                    [batch_size, self.label_size], dtype=self.dtype
+                )
+                batch["label_ptr"] = batch["label"].ctypes.data
+            if hasattr(self.sample_dims, "response"):
+                batch["response"] = np.empty(
+                    [batch_size, self.response_size], dtype=self.dtype
+                )
+                batch["response_ptr"] = batch["response"].ctypes.data
+            self.batch = batch

         def copy_to_array(i, sample):
-            shm_name, shm_size = sample.get()
-
-            shm = SharedMemory(name=shm_name)
-            shm_arr = np.ndarray(shm_size, dtype=self.dtype, buffer=shm.buf)
+            shm_name = sample.get()
+            shm = self.shms[shm_name]
+            shm_arr = np.ndarray(self.shm_size, dtype=self.dtype, buffer=shm.buf)

             offset = 0
             if hasattr(self.sample_dims, "sample"):
-                new_offset = offset + sample_size
-                batch["sample"][i, :] = shm_arr[offset:new_offset]
+                new_offset = offset + self.sample_size
+                self.batch["sample"][i, :] = shm_arr[offset:new_offset]
                 offset = new_offset
             if hasattr(self.sample_dims, "label"):
-                new_offset = offset + label_size
-                batch["label"][i, :] = shm_arr[offset:new_offset]
+                new_offset = offset + self.label_size
+                self.batch["label"][i, :] = shm_arr[offset:new_offset]
                 offset = new_offset
             if hasattr(self.sample_dims, "response"):
-                new_offset = offset + response_size
-                batch["response"][i, :] = shm_arr[offset:new_offset]
+                new_offset = offset + self.response_size
+                self.batch["response"][i, :] = shm_arr[offset:new_offset]
                 offset = new_offset

-            del shm_arr
-
-            shm.close()
-            shm.unlink()
+            self.returned_shms.append(shm_name)

         futures = []
         for i in range(batch_size):
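Caching the batch dict also keeps the raw pointers stable: the arrays are refilled in place on every call, so sample_ptr and friends, which appear to be handed to native code as plain addresses, stay valid across get_batch() calls. An illustration with hypothetical shapes:

    import numpy as np

    batch = np.empty([16, 8], dtype=np.float32)
    ptr = batch.ctypes.data
    batch[:] = 1.0                      # refill the same buffer in place
    assert batch.ctypes.data == ptr     # the address never changes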
@@ -347,8 +369,11 @@ def copy_to_array(i, sample):
             )

         cf.wait(futures)
+        # Check for any exceptions
+        for f in futures:
+            f.result()

-        return batch
+        return self.batch


 def construct_python_dataset_reader(
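The explicit result() loop matters because cf.wait() only blocks until the futures finish; it does not re-raise exceptions from the copy threads. A standalone illustration (not from the commit):

    import concurrent.futures as cf

    with cf.ThreadPoolExecutor(max_workers=1) as pool:
        futures = [pool.submit(lambda: 1 / 0)]
        cf.wait(futures)            # returns normally despite the failure
        try:
            futures[0].result()     # re-raises the ZeroDivisionError here
        except ZeroDivisionError:
            pass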
