
Adds h5py dump #221


Draft · wants to merge 12 commits into base: main
182 changes: 182 additions & 0 deletions lcls_tools/common/data/saver.py
@@ -0,0 +1,182 @@
from typing import Dict, Any
import h5py
import numpy as np
from pydantic import validate_call
import pandas as pd


class H5Saver:
"""
Class to dump and load dictionaries to and from HDF5 files.

Methods
-------
dump(data, filepath)
Dumps a dictionary to an HDF5 file.
load(filepath)
Loads a dictionary from an HDF5 file.
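
    Examples
    --------
    A minimal round trip (the file name is illustrative)::

        saver = H5Saver()
        saver.dump({"x": 1, "arr": np.array([1, 2, 3])}, "example.h5")
        restored = saver.load("example.h5")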
"""

def __init__(self):
"""
Initialize the H5Saver class.

        Attributes
        ----------
        string_dtype : str
            The encoding used when saving string data. Defaults to 'utf-8'.
        supported_types : tuple
            Scalar types that are written directly as HDF5 datasets.
"""
self.string_dtype = "utf-8"
self.supported_types = (bool, int, float, np.integer, np.floating)

@validate_call
def dump(self, data: Dict[str, Any], filepath: str):
"""
        Save a dictionary to an HDF5 file.

Parameters
----------
data : Dict[str, Any]
The dictionary to save.
filepath : str
The path to save the HDF5 file.

Returns
-------
None
"""
dt = h5py.string_dtype(encoding=self.string_dtype)
Collaborator:

Could you rename this to something like h5str? dt is not descriptive of what this does.


def recursive_save(d, f):
for key, val in d.items():
if key == "attrs":
f.attrs.update(val or h5py.Empty("f4"))
elif isinstance(val, dict):
group = f.create_group(key, track_order=True)
recursive_save(val, group)
elif isinstance(val, list):
if all(isinstance(ele, self.supported_types) for ele in val):
f.create_dataset(key, data=val, track_order=True)
elif all(isinstance(ele, np.ndarray) for ele in val):
# save np.arrays as datasets
for i, ele in enumerate(val):
                            # numeric arrays are stored directly; object-dtype arrays
                            # fall through to the string branch below
                            if ele.dtype != np.dtype("O"):
                                f.create_dataset(f"{key}/{i}", data=ele, track_order=True)
Collaborator:

I'd like to examine what this code is trying to accomplish by looking at the load function later down the line. load does three things: it converts scalar datasets into key/value pairs in a parent dictionary, it converts groups with the pandas attribute into pandas dataframes, and it converts all other groups into dictionaries.

Here's a pretty basic case:

list_of_dicts = {'list': [{'foo': 'a'}, {'bar': 'b'}]}
saver.dump(list_of_dicts, filepath)
result = saver.load(filepath)
# result = {'list': {'0': {'foo': 'a'}, '1': {'bar': 'b'}}}

The result is that the original list_of_dicts is not recovered. This isn't so bad, and we could reasonably recover the list as long as we know what's happening. But looking ahead, there are two cases in which the code saves data as strings, which means that data in a nested or mixed-type list can't be recovered easily.

There are a couple of approaches we could take to resolve this. We could create groups that carry information about the source class. That shouldn't be hard to implement, but it would mean writing a lot more code. The second approach is to reduce the number of cases this code handles to the ones that you and the rest of the ML team presently use and need. For instance, I assume you are not parsing a dictionary back out of its string form for a source dictionary that looks like this: {'list': [1, {'foo': 1}]}.

I would recommend the second approach. It looks like this code handles dictionaries and pandas dataframes best, so I would recommend sticking to those. We can iterate and add support for nested and mixed-type lists as time goes on.

This is all just my interpretation based on how the code is written right now. If there are other cases you need covered, please let me know.
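If we do end up wanting the first approach, here is a self-contained sketch of the marker-attribute idea (function and attribute names are purely illustrative, not part of this PR):

import h5py


# Sketch only: tag the group so the loader knows it came from a list.
def dump_list(filepath, key, list_of_dicts):
    with h5py.File(filepath, "w") as f:
        group = f.create_group(key, track_order=True)
        group.attrs["python_type"] = "list"
        for i, ele in enumerate(list_of_dicts):
            sub = group.create_group(str(i), track_order=True)
            for k, v in ele.items():
                sub.create_dataset(k, data=v)


def load_list(filepath, key):
    with h5py.File(filepath, "r") as f:
        group = f[key]
        assert group.attrs.get("python_type") == "list"
        out = []
        for name in sorted(group, key=int):
            item = {k: v[()] for k, v in group[name].items()}
            # h5py hands back bytes for stored strings, so decode them
            out.append({k: v.decode() if isinstance(v, bytes) else v for k, v in item.items()})
        return out


# dump_list("lists.h5", "list", [{'foo': 'a'}, {'bar': 'b'}])
# load_list("lists.h5", "list")  ->  [{'foo': 'a'}, {'bar': 'b'}]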

if ele.dtype == np.dtype("O"):
f.create_dataset(
f"{key}/{i}",
data=str(ele),
dtype=dt,
track_order=True,
)
elif all(isinstance(ele, dict) for ele in val):
# save dictionaries as groups recursively
for i, ele in enumerate(val):
group = f.create_group(f"{key}/{i}", track_order=True)
recursive_save(ele, group)
elif all(isinstance(ele, tuple) for ele in val):
# save tuples as np.array
for i, ele in enumerate(val):
val_array = np.array(ele)
f.create_dataset(
f"{key}/{i}", data=val_array, track_order=True
)
elif all(isinstance(ele, list) for ele in val):
# if it's a list of lists, save as np.array if homogeneous and type allows
# else save as strings
for i, ele in enumerate(val):
if all(isinstance(j, self.supported_types) for j in ele):
f.create_dataset(
f"{key}/{i}", data=np.array(ele), track_order=True
)
else:
f.create_dataset(
f"{key}/{i}",
data=str(ele),
dtype=dt,
track_order=True,
)
else:
for i, ele in enumerate(val):
# if it's a list of mixed types, save as strings
if isinstance(ele, str):
f.create_dataset(
f"{key}/{i}", data=ele, dtype=dt, track_order=True
)
else:
f.create_dataset(
f"{key}/{i}",
data=str(ele),
dtype=dt,
track_order=True,
)
elif isinstance(val, self.supported_types):
f.create_dataset(key, data=val, track_order=True)
elif isinstance(val, np.ndarray):
Collaborator:

ndarrays and tuples below can also contain dictionaries, but that case is not handled here. I'd recommend making that lack of implementation explicit, such as with an error.
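For example, something like this near the top of the ndarray branch would make the limitation explicit (the exception type and message are just a suggestion):

if val.dtype == np.dtype("O") and any(isinstance(item, dict) for item in val.ravel()):
    raise NotImplementedError(
        f"'{key}': object arrays containing dicts are not supported yet"
    )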

if val.dtype != np.dtype("O"):
f.create_dataset(key, data=val, track_order=True)
else:
f.create_dataset(key, data=str(val), dtype=dt, track_order=True)
elif isinstance(val, tuple):
val_array = np.array(val)
f.create_dataset(key, data=val_array, track_order=True)
elif isinstance(val, str):
# specify string dtype to avoid issues with encodings
f.create_dataset(key, data=val, dtype=dt, track_order=True)
elif isinstance(val, pd.DataFrame):
# save DataFrame as a group with datasets for columns
group = f.create_group(key)
group.attrs["pandas_type"] = "dataframe"
group.attrs["columns"] = list(val.columns)
for col in val.columns:
if val[col].dtype == np.dtype("O"):
try:
val[col] = val[col].astype("float64")
except ValueError:
val[col] = val[col].astype("string")
group.create_dataset(col, data=val[col].values)
else:
f.create_dataset(key, data=str(val), dtype=dt, track_order=True)

with h5py.File(filepath, "w") as file:
recursive_save(data, file)

def load(self, filepath):
"""
Load a dictionary from an HDF5 file.

Parameters
----------
filepath : str
The path to the file to load.

Returns
-------
dict
The dictionary loaded from the file.
"""

def recursive_load(f):
d = {"attrs": dict(f.attrs)} if f.attrs else {}
for key, val in f.items():
if isinstance(val, h5py.Group):
Collaborator:

Just to elaborate on my previous comment a little more: an h5py group becomes one of two things here, either a pandas dataframe or, by default, a dictionary. We could also write a group that gets reconstructed as a list, which would be a nice way to handle the "{key}/{i}" style groups.

I would really recommend that we shrink this PR down. It's difficult to recover data from its printed string form programmatically, and the right way to approach those cases is to expand the code to dump and recover them appropriately, which would add a lot of bulk to this already long PR. Therefore, those parts can be cut.

I think having a generalized way to dump and load data is a good idea in the long run, but we can iterate on our current approach to groups to make it more general.
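On the load side the change could stay small. A rough sketch, assuming dump tags those groups with a marker attribute in the same spirit as pandas_type (the python_type name is hypothetical):

elif val.attrs.get("python_type") == "list":
    # children were written as "0", "1", ... so rebuild the original order
    children = [val[name] for name in sorted(val, key=int)]
    d[key] = [
        recursive_load(child) if isinstance(child, h5py.Group) else child[()]
        for child in children
    ]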

if (
"pandas_type" in val.attrs
and val.attrs["pandas_type"] == "dataframe"
):
# Load DataFrame from group
columns = val.attrs["columns"]
data = {}
for col in columns:
data[col] = val[col][:]
d[key] = pd.DataFrame(data)
else:
Collaborator:

We already use a pandas_type attribute above to tell the loader what type we're dealing with. Maybe we should add an attribute for dictionaries too?
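For example, dump could tag plain-dict groups as it creates them (the attribute name is only a suggestion):

group = f.create_group(key, track_order=True)
group.attrs["python_type"] = "dict"  # marker so load() can tell dicts from other groups
recursive_save(val, group)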

d[key] = recursive_load(val)
elif isinstance(val, h5py.Dataset):
if isinstance(val[()], bytes):
d[key] = val[()].decode(self.string_dtype)
else:
d[key] = val[()]
return d

with h5py.File(filepath, "r") as file:
return recursive_load(file)
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -22,7 +22,8 @@ dependencies = [
"requests",
"pydantic",
"h5py",
"scikit-learn"
"scikit-learn",
"pandas"
]
description = "Tools to support high level application development at LCLS using Python"
dynamic = ["version"]
1 change: 1 addition & 0 deletions requirements.txt
@@ -6,6 +6,7 @@ pyyaml
requests
pydantic
h5py
pandas
scikit-learn
sphinx
sphinx_rtd_theme
Binary file added tests/datasets/images/numpy/test_images.npy
174 changes: 174 additions & 0 deletions tests/unit_tests/lcls_tools/common/data/test_saver.py
@@ -0,0 +1,174 @@
import os

import numpy as np

from lcls_tools.common.measurements.screen_profile import (
ScreenBeamProfileMeasurementResult,
)
from lcls_tools.common.image.processing import ImageProcessor
from lcls_tools.common.image.fit import ImageProjectionFit
from lcls_tools.common.data.saver import H5Saver


class TestSaver:
def test_nans(self):
saver = H5Saver()
data = {
"a": np.nan,
"b": np.inf,
"c": -np.inf,
"d": [np.nan, np.inf, -np.inf],
"e": [np.nan, np.inf, -np.inf, 1.0],
"f": [np.nan, np.inf, -np.inf, "a"],
"g": {"a": np.nan, "b": np.inf, "c": -np.inf},
"h": "np.Nan",
"i": np.array((1.0, 2.0), dtype="O"),
}
Collaborator:

Should we add assert statements here like in test_special_values below?
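For instance, reloading the file before the os.remove and spot-checking a few entries, along the lines of test_special_values:

loaded = saver.load("test.h5")
assert np.isnan(loaded["a"])
assert np.isinf(loaded["b"])
assert np.isneginf(loaded["c"])
assert loaded["h"] == "np.Nan"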

saver.dump(data, "test.h5")
os.remove("test.h5")

def test_screen_measurement_results(self):
# Load test data
images = np.load("tests/datasets/images/numpy/test_images.npy")

# Process data
image_processor = ImageProcessor()

processed_images = [image_processor.auto_process(image) for image in images]

rms_sizes = []
centroids = []
total_intensities = []
for image in processed_images:
fit_result = ImageProjectionFit().fit_image(image)
rms_sizes.append(fit_result.rms_size)
centroids.append(fit_result.centroid)
total_intensities.append(fit_result.total_intensity)

# Store results in ScreenBeamProfileMeasurementResult
result = ScreenBeamProfileMeasurementResult(
raw_images=images,
processed_images=processed_images,
rms_sizes=rms_sizes or None,
centroids=centroids or None,
total_intensities=total_intensities or None,
metadata={"info": "test"},
)

# Dump to H5
result_dict = result.model_dump()
Collaborator:

We already test that a saved dictionary has the same keys as the original when it is reloaded, and presumably ScreenBeamProfileMeasurementResult has its own tests that the result exposes the right attributes. So long as both of those things are tested separately, this test doesn't cover anything new.

saver = H5Saver()
saver.dump(
result_dict,
os.path.join("screen_test.h5"),
)

# Load H5
loaded_dict = saver.load("screen_test.h5")

# Check if the loaded dictionary is the same as the original
assert result_dict.keys() == loaded_dict.keys()
assert result_dict["metadata"] == loaded_dict["metadata"]
assert isinstance(loaded_dict["raw_images"], np.ndarray)
assert np.allclose(images, loaded_dict["raw_images"], rtol=1e-5)

mask = ~np.isnan(rms_sizes)
assert np.allclose(
np.asarray(rms_sizes)[mask], loaded_dict["rms_sizes"][mask], rtol=1e-5
)
mask = ~np.isnan(centroids)
assert np.allclose(
np.asarray(centroids)[mask], loaded_dict["centroids"][mask], rtol=1e-5
)
assert np.allclose(
total_intensities, loaded_dict["total_intensities"], rtol=1e-5
)

os.remove("screen_test.h5")

def test_basic_data_types(self):
saver = H5Saver()
data = {
"int": 42,
"float": 3.14,
"bool": True,
"string": "test",
"list": [1, 2, 3],
"tuple": (4, 5, 6),
"dict": {"a": 1, "b": 2},
"ndarray": np.array([7, 8, 9]),
}
saver.dump(data, "test_basic.h5")
loaded_data = saver.load("test_basic.h5")
os.remove("test_basic.h5")

assert data["int"] == loaded_data["int"]
assert data["float"] == loaded_data["float"]
assert data["bool"] == loaded_data["bool"]
assert data["string"] == loaded_data["string"]
assert data["list"] == loaded_data["list"].tolist()
assert (
list(data["tuple"]) == loaded_data["tuple"].tolist()
        ) # tuples are saved as arrays
assert data["dict"] == loaded_data["dict"]
assert np.array_equal(data["ndarray"], loaded_data["ndarray"])

def test_special_values(self):
saver = H5Saver()
data = {
"nan": np.nan,
"inf": np.inf,
"ninf": -np.inf,
"nan_list": [np.nan, np.inf, -np.inf],
}
saver.dump(data, "test_special.h5")
loaded_data = saver.load("test_special.h5")
os.remove("test_special.h5")

assert np.isnan(loaded_data["nan"])
assert np.isinf(loaded_data["inf"])
assert np.isneginf(loaded_data["ninf"])
assert np.isnan(loaded_data["nan_list"][0])
assert np.isinf(loaded_data["nan_list"][1])
assert np.isneginf(loaded_data["nan_list"][2])

def test_nested_structures(self):
saver = H5Saver()
data = {
"nested_dict": {"level1": {"level2": {"level3": "value"}}},
"nested_list": [[1, 2, 3], [4, 5, 6]],
}
saver.dump(data, "test_nested.h5")
loaded_data = saver.load("test_nested.h5")
os.remove("test_nested.h5")

assert data["nested_dict"] == loaded_data["nested_dict"]
for i in range(len(data["nested_list"])):
            # lists of lists are saved as groups and load back as dicts,
            # with each inner list coming back as an np.array
assert np.array_equal(
data["nested_list"][i], loaded_data["nested_list"][f"{i}"]
Collaborator:

Here's an example of the pattern I'd like to avoid. If we can add an attribute that tells the loader that we're dealing with a pandas dataframe, then there's no reason we can't do the same for a list.

)

def test_object_arrays(self):
saver = H5Saver()
data = {"object_array": np.array([1, "a", 3.14], dtype=object)}
saver.dump(data, "test_object_array.h5")
loaded_data = saver.load("test_object_array.h5")
os.remove("test_object_array.h5")

assert all(isinstance(item, str) for item in loaded_data["object_array"])

def test_list_of_ndarrays(self):
saver = H5Saver()
data = {"list_of_ndarrays": [np.array([1, 2, 3]), np.array([4, 5, 6])]}
saver.dump(data, "test_list_of_ndarrays.h5")
loaded_data = saver.load("test_list_of_ndarrays.h5")
os.remove("test_list_of_ndarrays.h5")

assert len(data["list_of_ndarrays"]) == len(loaded_data["list_of_ndarrays"])
for original, loaded in zip(
data["list_of_ndarrays"], loaded_data["list_of_ndarrays"].values()
):
# lists of ndarrays are saved as dicts
assert np.array_equal(original, loaded)