Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions newsfragments/881.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Cache detector module slices in FormatNXmx to avoid redundant HDF5 reads
12 changes: 8 additions & 4 deletions src/dxtbx/format/FormatNXmx.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ def _start(self):
dxtbx.nexus.get_static_mask(nxdetector)
)
self._bit_depth_readout = nxdetector.bit_depth_readout
self._nxdata_cached = nxdata
self._nxdetector_cached = nxdetector
self._module_slices_cached = dxtbx.nexus.get_detector_module_slices(nxdetector)

if self._scan_model:
self._num_images = len(self._scan_model)
Expand Down Expand Up @@ -158,11 +161,12 @@ def get_static_mask(self, index=None, goniometer=None):
return self._static_mask

def get_raw_data(self, index):
nxmx_obj = self._get_nxmx(self._cached_file_handle)
nxdata = nxmx_obj.entries[0].data[0]
nxdetector = nxmx_obj.entries[0].instruments[0].detectors[0]
raw_data = dxtbx.nexus.get_raw_data(
nxdata, nxdetector, index, bit_depth=self._bit_depth_readout
self._nxdata_cached,
self._nxdetector_cached,
index,
bit_depth=self._bit_depth_readout,
module_slices=self._module_slices_cached,
)
if self._bit_depth_readout:
# if 32 bit then it is a signed int, I think if 8, 16 then it is
Expand Down
19 changes: 12 additions & 7 deletions src/dxtbx/format/FormatNXmxEigerFilewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,18 @@ def _get_nxmx(self, fh: h5py.File):
return nxmx_obj

def get_raw_data(self, index):
nxmx_obj = self._get_nxmx(self._cached_file_handle)
nxdata = nxmx_obj.entries[0].data[0]
nxdetector = nxmx_obj.entries[0].instruments[0].detectors[0]

# Prefer bit_depth_image over bit_depth_readout since the former
# actually corresponds to the bit depth of the images as stored on
# disk. See also:
# https://www.dectris.com/support/downloads/header-docs/nexus/
bit_depth = self._bit_depth_image or self._bit_depth_readout
raw_data = get_raw_data(nxdata, nxdetector, index, bit_depth)
raw_data = get_raw_data(
self._nxdata_cached,
self._nxdetector_cached,
index,
bit_depth,
module_slices=self._module_slices_cached,
)

if bit_depth:
# if 32 bit then it is a signed int, I think if 8, 16 then it is
Expand All @@ -122,6 +124,7 @@ def get_raw_data(
nxdetector: nxmx.NXdetector,
index: int,
bit_depth: int | None = None,
module_slices: tuple[tuple[slice, ...], ...] | None = None,
) -> tuple[flex.float | flex.double | flex.int, ...]:
"""Return the raw data for an NXdetector.

Expand All @@ -142,9 +145,11 @@ def get_raw_data(
raise IndexError(f"Out of range index for raw data {index}")
all_data = []
sliced_outer = data[index]
for module_slices in get_detector_module_slices(nxdetector):
if module_slices is None:
module_slices = get_detector_module_slices(nxdetector)
for slices in module_slices:
data_as_flex = _dataset_as_flex(
sliced_outer, tuple(module_slices), bit_depth=bit_depth
sliced_outer, tuple(slices), bit_depth=bit_depth
)
all_data.append(data_as_flex)
return tuple(all_data)
7 changes: 5 additions & 2 deletions src/dxtbx/nexus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,7 @@ def get_raw_data(
nxdetector: nxmx.NXdetector,
index: int,
bit_depth: int | None = None,
module_slices: tuple[tuple[slice, ...], ...] | None = None,
) -> tuple[flex.float | flex.double | flex.int, ...]:
"""Return the raw data for an NXdetector.

Expand All @@ -583,9 +584,11 @@ def get_raw_data(
data = list(nxdata.values())[0]
all_data = []
sliced_outer = data[index]
for module_slices in get_detector_module_slices(nxdetector):
if module_slices is None:
module_slices = get_detector_module_slices(nxdetector)
for slices in module_slices:
data_as_flex = _dataset_as_flex(
sliced_outer, tuple(module_slices), bit_depth=bit_depth
sliced_outer, tuple(slices), bit_depth=bit_depth
)
data_as_flex.reshape(
flex.grid(data_as_flex.all()[-2:])
Expand Down
Loading