Adds h5py dump #221
lcls_tools/common/data/saver.py
@@ -0,0 +1,182 @@
from typing import Dict, Any
import h5py
import numpy as np
from pydantic import validate_call
import pandas as pd


class H5Saver:
    """
    Class to dump and load dictionaries to and from HDF5 files.

    Methods
    -------
    dump(data, filepath)
        Dumps a dictionary to an HDF5 file.
    load(filepath)
        Loads a dictionary from an HDF5 file.
    """

    def __init__(self):
        """
        Initialize the H5Saver class.

        Attributes
        ----------
        string_dtype : str
            The encoding used when saving string data. Default is 'utf-8'.
        supported_types : tuple
            Scalar types that can be written directly as HDF5 datasets.
        """
        self.string_dtype = "utf-8"
        self.supported_types = (bool, int, float, np.integer, np.floating)

    @validate_call
    def dump(self, data: Dict[str, Any], filepath: str):
        """
        Save a dictionary to an HDF5 file.

        Parameters
        ----------
        data : Dict[str, Any]
            The dictionary to save.
        filepath : str
            The path to save the HDF5 file.

        Returns
        -------
        None
        """
        dt = h5py.string_dtype(encoding=self.string_dtype)

        def recursive_save(d, f):
            for key, val in d.items():
                if key == "attrs":
                    f.attrs.update(val or h5py.Empty("f4"))
                elif isinstance(val, dict):
                    group = f.create_group(key, track_order=True)
                    recursive_save(val, group)
                elif isinstance(val, list):
                    if all(isinstance(ele, self.supported_types) for ele in val):
                        f.create_dataset(key, data=val, track_order=True)
                    elif all(isinstance(ele, np.ndarray) for ele in val):
                        # save np.ndarrays as indexed datasets under the key
                        for i, ele in enumerate(val):
                            if ele.dtype == np.dtype("O"):
                                # object arrays cannot be written directly;
                                # store their string representation instead
                                f.create_dataset(
                                    f"{key}/{i}",
                                    data=str(ele),
                                    dtype=dt,
                                    track_order=True,
                                )
                            else:
                                f.create_dataset(
                                    f"{key}/{i}", data=ele, track_order=True
                                )
I'd like to examine what this code is trying to accomplish by looking at what comes back from a round trip through dump and load. In a pretty basic case, the result is that the original data is not recovered in its original form. There are a couple of approaches we could take to resolving this. We could create groups that contain information about the source class. This shouldn't be so hard to implement, but it would mean writing a lot more code. The second approach is to reduce the number of cases this code handles to the ones that you and the rest of the ML team presently use and need. For instance, I assume you are not parsing dictionaries back out of their string form. I would recommend the second approach. It looks like this code handles dictionaries and pandas dataframes best, so I would recommend you stick to that. We can iterate and add support for nested lists and mixed-type lists as time goes on. This is all just my interpretation based on how the code is written right now. If there are other cases you need covered, please let me know.
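A minimal sketch of the failure mode under discussion, assuming a mixed-type list (file name arbitrary; this example is a reconstruction, not from the PR):

# Sketch of the round-trip issue: mixed-type lists are stringified on dump
saver = H5Saver()
saver.dump({"mixed": [1, "a", {"b": 2}]}, "roundtrip.h5")
loaded = saver.load("roundtrip.h5")
# the mixed list comes back as a dict of strings, e.g.
# {"0": "1", "1": "a", "2": "{'b': 2}"} -- not the original list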
                    elif all(isinstance(ele, dict) for ele in val):
                        # save dictionaries as groups recursively
                        for i, ele in enumerate(val):
                            group = f.create_group(f"{key}/{i}", track_order=True)
                            recursive_save(ele, group)
                    elif all(isinstance(ele, tuple) for ele in val):
                        # save tuples as np.array
                        for i, ele in enumerate(val):
                            val_array = np.array(ele)
                            f.create_dataset(
                                f"{key}/{i}", data=val_array, track_order=True
                            )
                    elif all(isinstance(ele, list) for ele in val):
                        # if it's a list of lists, save as np.array if homogeneous
                        # and the element types allow it; else save as strings
                        for i, ele in enumerate(val):
                            if all(isinstance(j, self.supported_types) for j in ele):
                                f.create_dataset(
                                    f"{key}/{i}", data=np.array(ele), track_order=True
                                )
                            else:
                                f.create_dataset(
                                    f"{key}/{i}",
                                    data=str(ele),
                                    dtype=dt,
                                    track_order=True,
                                )
                    else:
                        for i, ele in enumerate(val):
                            # if it's a list of mixed types, save as strings
                            if isinstance(ele, str):
                                f.create_dataset(
                                    f"{key}/{i}", data=ele, dtype=dt, track_order=True
                                )
                            else:
                                f.create_dataset(
                                    f"{key}/{i}",
                                    data=str(ele),
                                    dtype=dt,
                                    track_order=True,
                                )
                elif isinstance(val, self.supported_types):
                    f.create_dataset(key, data=val, track_order=True)
                elif isinstance(val, np.ndarray):
ndarrays and tuples below can also contain dictionaries, but that case is not handled here. I'd recommend making that lack of implementation explicit, such as with an error.
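A minimal sketch of that explicit failure (not part of this PR; the error message is illustrative):

# Sketch: refuse object arrays that hold dicts rather than silently
# falling through to the string branch below.
if val.dtype == np.dtype("O") and any(isinstance(x, dict) for x in val.flat):
    raise TypeError(f"dicts nested in ndarrays are not supported (key '{key}')")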
                    if val.dtype != np.dtype("O"):
                        f.create_dataset(key, data=val, track_order=True)
                    else:
                        f.create_dataset(key, data=str(val), dtype=dt, track_order=True)
                elif isinstance(val, tuple):
                    # save tuples as np.array
                    val_array = np.array(val)
                    f.create_dataset(key, data=val_array, track_order=True)
                elif isinstance(val, str):
                    # specify string dtype to avoid issues with encodings
                    f.create_dataset(key, data=val, dtype=dt, track_order=True)
                elif isinstance(val, pd.DataFrame):
                    # save DataFrame as a group with datasets for columns
                    group = f.create_group(key)
                    group.attrs["pandas_type"] = "dataframe"
                    group.attrs["columns"] = list(val.columns)
                    for col in val.columns:
                        if val[col].dtype == np.dtype("O"):
                            try:
                                val[col] = val[col].astype("float64")
                            except ValueError:
                                val[col] = val[col].astype("string")
                        group.create_dataset(col, data=val[col].values)
                else:
                    # fall back to the string representation of the value
                    f.create_dataset(key, data=str(val), dtype=dt, track_order=True)

        with h5py.File(filepath, "w") as file:
            recursive_save(data, file)
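One detail worth calling out: the reserved "attrs" key stores a sub-dictionary as HDF5 attributes rather than datasets. A small usage sketch (file and key names are arbitrary):

import h5py
from lcls_tools.common.data.saver import H5Saver

saver = H5Saver()
# a sub-dict under the reserved "attrs" key becomes HDF5 attributes
saver.dump({"beam": {"attrs": {"energy_mev": 135.0}, "sigma_x": 0.4}}, "attrs_demo.h5")
with h5py.File("attrs_demo.h5", "r") as f:
    assert f["beam"].attrs["energy_mev"] == 135.0  # attribute, not dataset
    assert f["beam/sigma_x"][()] == 0.4            # regular dataset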
    def load(self, filepath):
        """
        Load a dictionary from an HDF5 file.

        Parameters
        ----------
        filepath : str
            The path to the file to load.

        Returns
        -------
        dict
            The dictionary loaded from the file.
        """

        def recursive_load(f):
            d = {"attrs": dict(f.attrs)} if f.attrs else {}
            for key, val in f.items():
                if isinstance(val, h5py.Group):
Just to elaborate on my previous comment a little more, h5py groups do two things here: either they create pandas dataframes or dictionaries by default. We could also write a group that gets reconstructed as a list, which would be a nice way to reconstruct "{key}/{i}"-type groups. I would really recommend that we shrink this PR down. It's difficult to recover data from its printed string form programmatically, and the right way to approach those cases is to expand the code to dump and recover them appropriately, which would add a lot of bulk to this already long PR. Therefore, those parts can be cut. I think having a generalized way to dump and load data is a good idea in the long run, but we can iterate on our current approach to groups to make it general.
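To make that concrete, a sketch of a list-aware group, mirroring the existing pandas_type convention (the "python_type" attribute name is hypothetical, not part of this PR):

# dump side: mark the group so the loader knows it was a list
group = f.create_group(key, track_order=True)
group.attrs["python_type"] = "list"
for i, ele in enumerate(val):
    group.create_dataset(str(i), data=ele, track_order=True)

# load side: rebuild the list from the indexed datasets
if val.attrs.get("python_type") == "list":
    d[key] = [val[str(i)][()] for i in range(len(val))]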
                    if (
                        "pandas_type" in val.attrs
                        and val.attrs["pandas_type"] == "dataframe"
                    ):
                        # Load DataFrame from group
                        columns = val.attrs["columns"]
                        data = {}
                        for col in columns:
                            data[col] = val[col][:]
                        d[key] = pd.DataFrame(data)
                    else:
We already use a pandas_type attribute to tell the loader how to reconstruct a dataframe here; the same mechanism could mark groups that should come back as lists.
                        d[key] = recursive_load(val)
                elif isinstance(val, h5py.Dataset):
                    if isinstance(val[()], bytes):
                        d[key] = val[()].decode(self.string_dtype)
                    else:
                        d[key] = val[()]
            return d

        with h5py.File(filepath, "r") as file:
            return recursive_load(file)
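For reference, a minimal happy-path usage sketch (the import path follows the tests below; file name and values are arbitrary):

import numpy as np
import pandas as pd
from lcls_tools.common.data.saver import H5Saver

saver = H5Saver()
data = {
    "scalar": 3.14,
    "array": np.arange(5),
    "nested": {"flag": True, "label": "test"},
    "frame": pd.DataFrame({"x": [1.0, 2.0], "y": [3.0, 4.0]}),
}
saver.dump(data, "example.h5")   # write the dict to HDF5
loaded = saver.load("example.h5")  # read it back as a dict
assert loaded["nested"]["label"] == "test"
assert np.array_equal(loaded["array"], data["array"])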
requirements.txt
@@ -6,6 +6,7 @@ pyyaml
 requests
 pydantic
 h5py
+pandas
 scikit-learn
 sphinx
 sphinx_rtd_theme
@@ -0,0 +1,174 @@
import os

import numpy as np

from lcls_tools.common.measurements.screen_profile import (
    ScreenBeamProfileMeasurementResult,
)
from lcls_tools.common.image.processing import ImageProcessor
from lcls_tools.common.image.fit import ImageProjectionFit
from lcls_tools.common.data.saver import H5Saver


class TestSaver:
    def test_nans(self):
        saver = H5Saver()
        data = {
            "a": np.nan,
            "b": np.inf,
            "c": -np.inf,
            "d": [np.nan, np.inf, -np.inf],
            "e": [np.nan, np.inf, -np.inf, 1.0],
            "f": [np.nan, np.inf, -np.inf, "a"],
            "g": {"a": np.nan, "b": np.inf, "c": -np.inf},
            "h": "np.Nan",
            "i": np.array((1.0, 2.0), dtype="O"),
        }
Should we add assert statements here like in the other tests?
        saver.dump(data, "test.h5")
        os.remove("test.h5")
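If assertions were added (before the os.remove cleanup), they might look like this sketch, which is not part of the PR:

        # Sketch: round-trip and check the special values
        loaded = saver.load("test.h5")
        assert np.isnan(loaded["a"])
        assert np.isposinf(loaded["b"]) and np.isneginf(loaded["c"])
        assert loaded["h"] == "np.Nan"  # plain strings round-trip unchanged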
    def test_screen_measurement_results(self):
        # Load test data
        images = np.load("tests/datasets/images/numpy/test_images.npy")

        # Process data
        image_processor = ImageProcessor()

        processed_images = [image_processor.auto_process(image) for image in images]

        rms_sizes = []
        centroids = []
        total_intensities = []
        for image in processed_images:
            fit_result = ImageProjectionFit().fit_image(image)
            rms_sizes.append(fit_result.rms_size)
            centroids.append(fit_result.centroid)
            total_intensities.append(fit_result.total_intensity)

        # Store results in ScreenBeamProfileMeasurementResult
        result = ScreenBeamProfileMeasurementResult(
            raw_images=images,
            processed_images=processed_images,
            rms_sizes=rms_sizes or None,
            centroids=centroids or None,
            total_intensities=total_intensities or None,
            metadata={"info": "test"},
        )

        # Dump to H5
        result_dict = result.model_dump()
We already test to make sure that a saved dictionary has the same keys as the original when reloaded. Presumably screen beam measurement already tests to make sure the result has the right attributes on output. So long as both of those things are tested separately, this test doesn't test anything new.
        saver = H5Saver()
        saver.dump(
            result_dict,
            os.path.join("screen_test.h5"),
        )

        # Load H5
        loaded_dict = saver.load("screen_test.h5")

        # Check if the loaded dictionary is the same as the original
        assert result_dict.keys() == loaded_dict.keys()
        assert result_dict["metadata"] == loaded_dict["metadata"]
        assert isinstance(loaded_dict["raw_images"], np.ndarray)
        assert np.allclose(images, loaded_dict["raw_images"], rtol=1e-5)

        mask = ~np.isnan(rms_sizes)
        assert np.allclose(
            np.asarray(rms_sizes)[mask], loaded_dict["rms_sizes"][mask], rtol=1e-5
        )
        mask = ~np.isnan(centroids)
        assert np.allclose(
            np.asarray(centroids)[mask], loaded_dict["centroids"][mask], rtol=1e-5
        )
        assert np.allclose(
            total_intensities, loaded_dict["total_intensities"], rtol=1e-5
        )

        os.remove("screen_test.h5")
    def test_basic_data_types(self):
        saver = H5Saver()
        data = {
            "int": 42,
            "float": 3.14,
            "bool": True,
            "string": "test",
            "list": [1, 2, 3],
            "tuple": (4, 5, 6),
            "dict": {"a": 1, "b": 2},
            "ndarray": np.array([7, 8, 9]),
        }
        saver.dump(data, "test_basic.h5")
        loaded_data = saver.load("test_basic.h5")
        os.remove("test_basic.h5")

        assert data["int"] == loaded_data["int"]
        assert data["float"] == loaded_data["float"]
        assert data["bool"] == loaded_data["bool"]
        assert data["string"] == loaded_data["string"]
        assert data["list"] == loaded_data["list"].tolist()
        assert (
            list(data["tuple"]) == loaded_data["tuple"].tolist()
        )  # tuples are saved as arrays
        assert data["dict"] == loaded_data["dict"]
        assert np.array_equal(data["ndarray"], loaded_data["ndarray"])
    def test_special_values(self):
        saver = H5Saver()
        data = {
            "nan": np.nan,
            "inf": np.inf,
            "ninf": -np.inf,
            "nan_list": [np.nan, np.inf, -np.inf],
        }
        saver.dump(data, "test_special.h5")
        loaded_data = saver.load("test_special.h5")
        os.remove("test_special.h5")

        assert np.isnan(loaded_data["nan"])
        assert np.isinf(loaded_data["inf"])
        assert np.isneginf(loaded_data["ninf"])
        assert np.isnan(loaded_data["nan_list"][0])
        assert np.isinf(loaded_data["nan_list"][1])
        assert np.isneginf(loaded_data["nan_list"][2])
    def test_nested_structures(self):
        saver = H5Saver()
        data = {
            "nested_dict": {"level1": {"level2": {"level3": "value"}}},
            "nested_list": [[1, 2, 3], [4, 5, 6]],
        }
        saver.dump(data, "test_nested.h5")
        loaded_data = saver.load("test_nested.h5")
        os.remove("test_nested.h5")

        assert data["nested_dict"] == loaded_data["nested_dict"]
        for i in range(len(data["nested_list"])):
            # lists of lists are saved as dicts keyed by index;
            # here each inner list comes back as an np.ndarray
            assert np.array_equal(
                data["nested_list"][i], loaded_data["nested_list"][f"{i}"]
Here's an example of the pattern I'd like to avoid. If we can add an attribute that tells the loader that we're dealing with a pandas dataframe, then there's no reason we can't do the same for a list.
            )

    def test_object_arrays(self):
        saver = H5Saver()
        data = {"object_array": np.array([1, "a", 3.14], dtype=object)}
        saver.dump(data, "test_object_array.h5")
        loaded_data = saver.load("test_object_array.h5")
        os.remove("test_object_array.h5")

        assert all(isinstance(item, str) for item in loaded_data["object_array"])

    def test_list_of_ndarrays(self):
        saver = H5Saver()
        data = {"list_of_ndarrays": [np.array([1, 2, 3]), np.array([4, 5, 6])]}
        saver.dump(data, "test_list_of_ndarrays.h5")
        loaded_data = saver.load("test_list_of_ndarrays.h5")
        os.remove("test_list_of_ndarrays.h5")

        assert len(data["list_of_ndarrays"]) == len(loaded_data["list_of_ndarrays"])
        for original, loaded in zip(
            data["list_of_ndarrays"], loaded_data["list_of_ndarrays"].values()
        ):
            # lists of ndarrays are saved as dicts
            assert np.array_equal(original, loaded)
Could you rename this to something like h5str? dt is not descriptive of what this does.