Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 19 additions & 14 deletions applications/physics/cosmology/cosmoflow/cosmoflow_dataset.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,36 @@
import numpy as np
from glob import glob
from lbann.util.data import Sample, SampleDims, Dataset, DistConvDataset
from lbann.util.data import Sample, SampleDims, DistConvDataset
import h5py as h5
import os


class CosmoFlowDataset(DistConvDataset):
def __init__(self, data_dir, input_width, num_secrets):
self.data_dir = data_dir
self.input_width = input_width
self.num_secrets = num_secrets
self.samples = glob(os.path.join(data_dir, '*.hdf5'))
self.samples = glob(os.path.join(data_dir, "*.hdf5"))
self.samples.sort()

def __len__(self):
return len(self.samples)

def __getitem__(self, index) -> Sample:
data = h5.File(self.samples[index], 'r')
data = h5.File(self.samples[index], "r")
slice_width = self.input_width // self.num_io_partitions
slice_ind = self.rank % self.num_io_partitions
full = data['full'][:,
slice_ind*slice_width:(slice_ind+1)*slice_width,
:self.input_width,
:self.input_width].astype(np.float32)
par = data['unitPar'][:].astype(np.float32)
return Sample(sample=np.ascontiguousarray(full), response=par)

full = data["full"][
:,
slice_ind * slice_width : (slice_ind + 1) * slice_width,
: self.input_width,
: self.input_width,
]
par = data["unitPar"][:].astype(np.float32)
return Sample(sample=full, response=par)

def get_sample_dims(self):
return SampleDims(sample=[4, self.input_width, self.input_width, self.input_width], response=self.num_secrets)
return SampleDims(
sample=[4, self.input_width, self.input_width, self.input_width],
response=self.num_secrets,
)
115 changes: 70 additions & 45 deletions python/lbann/util/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Dict, List, Optional, Union
from numpy.typing import ArrayLike
import concurrent.futures as cf
from multiprocessing import resource_tracker


class Sample:
Expand Down Expand Up @@ -183,11 +184,14 @@ def __init__(
self.dataset = dataset
self.num_procs = num_procs
self.prefetch_factor = prefetch_factor
self.dtype = dtype
self.dtype = np.dtype(dtype)
self.sample_dims = dataset.get_sample_dims()
self.num_io_partitions = 1
self.loaded_samples = []
self.thread_pool = cf.ThreadPoolExecutor(max_workers=num_procs)
self.shms = {}
self.returned_shms = []
self.batch = None

if isinstance(self.dataset, DistConvDataset):
self.num_io_partitions = self.dataset.num_io_partitions
Expand All @@ -198,6 +202,19 @@ def __init__(
initargs=(self.dataset,),
)

self.shm_size = 0
if hasattr(self.sample_dims, "sample"):
self.sample_size = (
np.prod(self.sample_dims.sample) // self.num_io_partitions
)
self.shm_size += self.sample_size
if hasattr(self.sample_dims, "label"):
self.label_size = np.prod(self.sample_dims.sample)
self.shm_size += self.label_size
if hasattr(self.sample_dims, "response"):
self.response_size = self.sample_dims.response
self.shm_size += self.response_size

@staticmethod
def init_worker(dataset):
"""
Expand Down Expand Up @@ -225,8 +242,12 @@ def terminate(self) -> None:
"""
self.pool.terminate()

for shm in self.shms.values():
shm.close()
shm.unlink()

@staticmethod
def load_sample(ind) -> Sample:
def load_sample(ind, shm_name, shm_size, dtype) -> Sample:
"""
Loads the sample from the dataset at the specified index.
This function must be called from a worker process.
Expand All @@ -240,19 +261,10 @@ def load_sample(ind) -> Sample:
"""
samp = g_dataset[ind]

shm_size = 0
dtype = None
if hasattr(samp, "sample"):
dtype = samp.sample.dtype
shm_size += samp.sample.size
if hasattr(samp, "label"):
dtype = samp.label.dtype
shm_size += samp.label.size
if hasattr(samp, "response"):
dtype = samp.response.dtype
shm_size += samp.response.size

shm = SharedMemory(create=True, size=shm_size * dtype.itemsize)
shm = SharedMemory(name=shm_name)
resource_tracker.unregister(
shm._name, "shared_memory"
) # Prevent the resource tracker from interfering during process pool shutdown
shm_arr = np.ndarray(shm_size, dtype=dtype, buffer=shm.buf)

offset = 0
Expand All @@ -270,14 +282,23 @@ def load_sample(ind) -> Sample:
offset = new_offset

shm.close()
return shm.name, shm_size
return shm.name

def load_next_sample_async(self, ind: int):
"""
Submit the next sample index to be loaded to the worker pool.
"""
if not self.returned_shms:
shm = SharedMemory(create=True, size=self.shm_size * self.dtype.itemsize)
shm_name = shm.name
self.shms[shm_name] = shm
else:
shm_name = self.returned_shms.pop()

self.loaded_samples.append(
self.pool.apply_async(DataReader.load_sample, (ind,))
self.pool.apply_async(
DataReader.load_sample, (ind, shm_name, self.shm_size, self.dtype)
)
)

def queue_samples(self, inds: List[int]) -> None:
Expand All @@ -301,44 +322,45 @@ def get_batch(self, batch_size: int) -> Dict[str, Union[np.ndarray, int]]:
:rtype: Dict[str, Union[np.ndarray, int]]
"""

batch = {}
if hasattr(self.sample_dims, "sample"):
sample_size = np.prod(self.sample_dims.sample) // self.num_io_partitions
batch["sample"] = np.empty([batch_size, sample_size], dtype=self.dtype)
batch["sample_ptr"] = batch["sample"].ctypes.data
if hasattr(self.sample_dims, "label"):
label_size = np.prod(self.sample_dims.sample)
batch["label"] = np.empty([batch_size, label_size], dtype=self.dtype)
batch["label_ptr"] = batch["label"].ctypes.data
if hasattr(self.sample_dims, "response"):
response_size = self.sample_dims.response
batch["response"] = np.empty([batch_size, response_size], dtype=self.dtype)
batch["response_ptr"] = batch["response"].ctypes.data
if self.batch is None:
batch = {}
if hasattr(self.sample_dims, "sample"):
batch["sample"] = np.empty(
[batch_size, self.sample_size], dtype=self.dtype
)
batch["sample_ptr"] = batch["sample"].ctypes.data
if hasattr(self.sample_dims, "label"):
batch["label"] = np.empty(
[batch_size, self.label_size], dtype=self.dtype
)
batch["label_ptr"] = batch["label"].ctypes.data
if hasattr(self.sample_dims, "response"):
batch["response"] = np.empty(
[batch_size, self.response_size], dtype=self.dtype
)
batch["response_ptr"] = batch["response"].ctypes.data
self.batch = batch

def copy_to_array(i, sample):
shm_name, shm_size = sample.get()

shm = SharedMemory(name=shm_name)
shm_arr = np.ndarray(shm_size, dtype=self.dtype, buffer=shm.buf)
shm_name = sample.get()
shm = self.shms[shm_name]
shm_arr = np.ndarray(self.shm_size, dtype=self.dtype, buffer=shm.buf)

offset = 0
if hasattr(self.sample_dims, "sample"):
new_offset = offset + sample_size
batch["sample"][i, :] = shm_arr[offset:new_offset]
new_offset = offset + self.sample_size
self.batch["sample"][i, :] = shm_arr[offset:new_offset]
offset = new_offset
if hasattr(self.sample_dims, "label"):
new_offset = offset + label_size
batch["label"][i, :] = shm_arr[offset:new_offset]
new_offset = offset + self.label_size
self.batch["label"][i, :] = shm_arr[offset:new_offset]
offset = new_offset
if hasattr(self.sample_dims, "response"):
new_offset = offset + response_size
batch["response"][i, :] = shm_arr[offset:new_offset]
new_offset = offset + self.response_size
self.batch["response"][i, :] = shm_arr[offset:new_offset]
offset = new_offset

del shm_arr

shm.close()
shm.unlink()
self.returned_shms.append(shm_name)

futures = []
for i in range(batch_size):
Expand All @@ -347,8 +369,11 @@ def copy_to_array(i, sample):
)

cf.wait(futures)
# Check for any exceptions
for f in futures:
f.result()

return batch
return self.batch


def construct_python_dataset_reader(
Expand Down
Loading