Skip to content

Bugfix/write to hdf5 #39

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 239 additions & 0 deletions proteka/dataset/ensemble_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
from proteka import Ensemble, UnitSystem, Quantity
from proteka.quantity.quantity_shapes import PerFrameQuantity
from proteka.dataset.ensemble import HDF5Group
from proteka.metrics import Featurizer
from proteka.metrics.utils import get_CLN_trajectory
import pytest
import numpy as np
import mdtraj as md

from .top_utils import json2top, top2json

Expand All @@ -21,6 +24,27 @@ def example_ensemble(example_json_topology):
return ensemble


@pytest.fixture
def cln_example_ensemble():
"""Create a noisy CLN ensemble."""
cln_trajs = [get_CLN_trajectory(seed=i) for i in range(10)]
trajectory_slices = {}
current_frames = 0
coords = []
for i, traj in enumerate(cln_trajs):
frames = traj.xyz.shape[0]
coords.append(traj.xyz)
slc = slice(current_frames, current_frames + frames, 1)
trajectory_slices["traj_{i}"] = slc
current_frames += frames
ensemble = Ensemble(
name="cln_example_ensemble",
top=cln_trajs[0].top,
coords=np.concatenate(coords),
)
return ensemble


def test_casting(example_ensemble):
assert example_ensemble["coords"].unit == "nanometers"
example_ensemble.forces = np.zeros((10, example_ensemble.n_atoms, 3))
Expand Down Expand Up @@ -131,6 +155,221 @@ def test_ensemble_to_h5(example_ensemble, tmpdir):
assert ensemble2["custom_field"].unit == "dimensionless"


def test_featurized_all_ensemble_to_h5(cln_example_ensemble, tmpdir):
"""Test saving and loading ensembles."""
ensemble = cln_example_ensemble
ensemble.forces = np.zeros((ensemble.n_frames, ensemble.n_atoms, 3))
ensemble.custom_field = np.zeros((5, ensemble.n_atoms, 3))
# random weights inspired by MSM analysis.
weights = np.random.rand(ensemble.n_frames)
ensemble.set_quantity("weights", weights)
# quantities from a featurizer.
root_dir = Path(__file__).parent.parent.parent
cln_path = (
root_dir / "examples" / "example_dataset_files" / "cln_folded.pdb"
)
cln_native_structure = md.load_pdb(cln_path)
feat = Featurizer()
feat.add_rg(ensemble)
feat.add_fraction_native_contacts(
ensemble, reference_structure=cln_native_structure
)
feat.add_rmsd(ensemble, reference_structure=cln_native_structure)
feat.add_ca_distances(ensemble, offset=1, subset_selection="name CA")
with h5py.File(tmpdir / "test.h5", "w") as f:
ensemble.write_to_hdf5(f, name="example_ensemble")

with h5py.File(tmpdir / "test.h5", "r") as f:
group = f["example_ensemble"]
ensemble2 = Ensemble.from_hdf5(group)

assert ensemble2.n_frames == ensemble.n_frames
assert np.allclose(ensemble2.coords, ensemble.coords)
assert np.allclose(ensemble2.forces, ensemble.forces)
assert np.allclose(ensemble2.custom_field, ensemble.custom_field)
assert set(ensemble2.list_quantities()) == set(ensemble.list_quantities())
assert np.allclose(
ensemble2.get_quantity("weights").raw_value,
ensemble.get_quantity("weights").raw_value,
)
assert np.allclose(
ensemble2.get_quantity("rg").raw_value,
ensemble.get_quantity("rg").raw_value,
)
assert np.allclose(
ensemble2.get_quantity("rmsd").raw_value,
ensemble.get_quantity("rmsd").raw_value,
)
assert np.allclose(
ensemble2.get_quantity("fraction_native_contacts").raw_value,
ensemble.get_quantity("fraction_native_contacts").raw_value,
)
assert np.allclose(
ensemble2.get_quantity("ca_distances").raw_value,
ensemble.get_quantity("ca_distances").raw_value,
)
assert ensemble2.top == ensemble.top
assert ensemble2.name == ensemble.name
assert ensemble2["custom_field"].unit == "dimensionless"


def test_featurized_weights_ensemble_to_h5(cln_example_ensemble, tmpdir):
"""Test saving and loading ensembles."""
ensemble = cln_example_ensemble
ensemble.forces = np.zeros((ensemble.n_frames, ensemble.n_atoms, 3))
ensemble.custom_field = np.zeros((5, ensemble.n_atoms, 3))
# random weights inspired by MSM analysis.
weights = np.random.rand(ensemble.n_frames)
ensemble.set_quantity("weights", weights)

with h5py.File(tmpdir / "test.h5", "w") as f:
ensemble.write_to_hdf5(f, name="example_ensemble")

with h5py.File(tmpdir / "test.h5", "r") as f:
group = f["example_ensemble"]
ensemble2 = Ensemble.from_hdf5(group)

assert ensemble2.n_frames == ensemble.n_frames
assert np.allclose(ensemble2.coords, ensemble.coords)
assert np.allclose(ensemble2.forces, ensemble.forces)
assert np.allclose(ensemble2.custom_field, ensemble.custom_field)
assert set(ensemble2.list_quantities()) == set(ensemble.list_quantities())
assert np.allclose(
ensemble2.get_quantity("weights").raw_value,
ensemble.get_quantity("weights").raw_value,
)
assert ensemble2.top == ensemble.top
assert ensemble2.name == ensemble.name
assert ensemble2["custom_field"].unit == "dimensionless"


def test_featurized_rg_ensemble_to_h5(cln_example_ensemble, tmpdir):
"""Test saving and loading ensembles."""
ensemble = cln_example_ensemble
ensemble.forces = np.zeros((ensemble.n_frames, ensemble.n_atoms, 3))
ensemble.custom_field = np.zeros((5, ensemble.n_atoms, 3))
feat = Featurizer()
feat.add_rg(ensemble)
with h5py.File(tmpdir / "test.h5", "w") as f:
ensemble.write_to_hdf5(f, name="example_ensemble")

with h5py.File(tmpdir / "test.h5", "r") as f:
group = f["example_ensemble"]
ensemble2 = Ensemble.from_hdf5(group)

assert ensemble2.n_frames == ensemble.n_frames
assert np.allclose(ensemble2.coords, ensemble.coords)
assert np.allclose(ensemble2.forces, ensemble.forces)
assert np.allclose(ensemble2.custom_field, ensemble.custom_field)
assert set(ensemble2.list_quantities()) == set(ensemble.list_quantities())
assert np.allclose(
ensemble2.get_quantity("rg").raw_value,
ensemble.get_quantity("rg").raw_value,
)
assert ensemble2.top == ensemble.top
assert ensemble2.name == ensemble.name
assert ensemble2["custom_field"].unit == "dimensionless"


def test_featurized_fraction_native_contacs_ensemble_to_h5(
cln_example_ensemble, tmpdir
):
"""Test saving and loading ensembles."""
ensemble = cln_example_ensemble
ensemble.forces = np.zeros((ensemble.n_frames, ensemble.n_atoms, 3))
ensemble.custom_field = np.zeros((5, ensemble.n_atoms, 3))
root_dir = Path(__file__).parent.parent.parent
cln_path = (
root_dir / "examples" / "example_dataset_files" / "cln_folded.pdb"
)
cln_native_structure = md.load_pdb(cln_path)
feat = Featurizer()
feat.add_fraction_native_contacts(
ensemble, reference_structure=cln_native_structure
)
with h5py.File(tmpdir / "test.h5", "w") as f:
ensemble.write_to_hdf5(f, name="example_ensemble")

with h5py.File(tmpdir / "test.h5", "r") as f:
group = f["example_ensemble"]
ensemble2 = Ensemble.from_hdf5(group)

assert ensemble2.n_frames == ensemble.n_frames
assert np.allclose(ensemble2.coords, ensemble.coords)
assert np.allclose(ensemble2.forces, ensemble.forces)
assert np.allclose(ensemble2.custom_field, ensemble.custom_field)
assert set(ensemble2.list_quantities()) == set(ensemble.list_quantities())
assert np.allclose(
ensemble2.get_quantity("fraction_native_contacts").raw_value,
ensemble.get_quantity("fraction_native_contacts").raw_value,
)
assert ensemble2.top == ensemble.top
assert ensemble2.name == ensemble.name
assert ensemble2["custom_field"].unit == "dimensionless"


def test_featurized_rmsd_ensemble_to_h5(cln_example_ensemble, tmpdir):
"""Test saving and loading ensembles."""
ensemble = cln_example_ensemble
ensemble.forces = np.zeros((ensemble.n_frames, ensemble.n_atoms, 3))
ensemble.custom_field = np.zeros((5, ensemble.n_atoms, 3))
root_dir = Path(__file__).parent.parent.parent
cln_path = (
root_dir / "examples" / "example_dataset_files" / "cln_folded.pdb"
)
cln_native_structure = md.load_pdb(cln_path)
feat = Featurizer()
feat.add_rmsd(ensemble, reference_structure=cln_native_structure)
with h5py.File(tmpdir / "test.h5", "w") as f:
ensemble.write_to_hdf5(f, name="example_ensemble")

with h5py.File(tmpdir / "test.h5", "r") as f:
group = f["example_ensemble"]
ensemble2 = Ensemble.from_hdf5(group)

assert ensemble2.n_frames == ensemble.n_frames
assert np.allclose(ensemble2.coords, ensemble.coords)
assert np.allclose(ensemble2.forces, ensemble.forces)
assert np.allclose(ensemble2.custom_field, ensemble.custom_field)
assert set(ensemble2.list_quantities()) == set(ensemble.list_quantities())
assert np.allclose(
ensemble2.get_quantity("rmsd").raw_value,
ensemble.get_quantity("rmsd").raw_value,
)
assert ensemble2.top == ensemble.top
assert ensemble2.name == ensemble.name
assert ensemble2["custom_field"].unit == "dimensionless"


def test_featurized_ca_dist_ensemble_to_h5(cln_example_ensemble, tmpdir):
"""Test saving and loading ensembles."""
ensemble = cln_example_ensemble
ensemble.forces = np.zeros((ensemble.n_frames, ensemble.n_atoms, 3))
ensemble.custom_field = np.zeros((5, ensemble.n_atoms, 3))
# quantities from a featurizer.
feat = Featurizer()
feat.add_ca_distances(ensemble, offset=1, subset_selection="name CA")
with h5py.File(tmpdir / "test.h5", "w") as f:
ensemble.write_to_hdf5(f, name="example_ensemble")

with h5py.File(tmpdir / "test.h5", "r") as f:
group = f["example_ensemble"]
ensemble2 = Ensemble.from_hdf5(group)

assert ensemble2.n_frames == ensemble.n_frames
assert np.allclose(ensemble2.coords, ensemble.coords)
assert np.allclose(ensemble2.forces, ensemble.forces)
assert np.allclose(ensemble2.custom_field, ensemble.custom_field)
assert set(ensemble2.list_quantities()) == set(ensemble.list_quantities())
assert np.allclose(
ensemble2.get_quantity("ca_distances").raw_value,
ensemble.get_quantity("ca_distances").raw_value,
)
assert ensemble2.top == ensemble.top
assert ensemble2.name == ensemble.name
assert ensemble2["custom_field"].unit == "dimensionless"


def test_ensemble_with_extra_builtin_quantity(example_ensemble):
unit_system = UnitSystem(
"Angstrom",
Expand Down
39 changes: 20 additions & 19 deletions proteka/metrics/featurizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import warnings
import numpy as np
import mdtraj as md
from typing import Dict, Optional, List, Tuple
from ..dataset import Ensemble
from ..quantity import Quantity
from ..dataset.top_utils import top2json
Expand Down Expand Up @@ -484,10 +483,8 @@ def add_rmsd(

metadata = {
"feature": "rmsd",
"reference_structure": {
"coords": ref_coords,
"top": ref_top,
},
"reference_structure_coords": ref_coords,
"reference_structure_top": ref_top,
"atom_selection": atom_selection,
}

Expand Down Expand Up @@ -881,11 +878,9 @@ def add_fraction_native_contacts(

metadata = {
"feature": "fraction_native_contacts",
"reference_structure": {
"coords": native_coords,
"top": native_top,
"atomistic": use_atomistic_reference,
},
"reference_structure_coords": native_coords,
"reference_structure_top": top2json(native_top),
"reference_structure_is_atomistic": use_atomistic_reference,
"atom_selection": atom_selection,
"beta": beta,
"lam": lam,
Expand Down Expand Up @@ -1132,16 +1127,19 @@ def add_temporary_feature(
@staticmethod
def _reference_structure_equality(
input_structure: md.Trajectory,
serialized_structure: Dict,
stored_coords: Union[List, np.ndarray],
stored_top: str,
) -> bool:
"""Helper method for testing reference structure serialized equality for RMSD recomputation

Parameters
----------
input_structures:
input MDTraj single frame Trajectory for proposed RMSD calculations
serialized_structure:
Serialized reference structure
stored_coords:
Saved reference structure coordinates
stored_top:
Saved reference structure topology

Returns
-------
Expand All @@ -1150,14 +1148,13 @@ def _reference_structure_equality(
the same, True is returned. Else, False is returned.
"""

ref_coords = input_structure.xyz.tolist()
ref_coords = np.array(input_structure.xyz)
stored_coords = np.array(stored_coords)
ref_top = top2json(input_structure.topology)
stored_coords = serialized_structure["coords"]
stored_top = serialized_structure["top"]

equals = []
equals.append(stored_top == ref_top)
equals.append(stored_coords == ref_coords)
equals.append(np.allclose(stored_coords, ref_coords))
return all(equals)

@staticmethod
Expand Down Expand Up @@ -1238,7 +1235,8 @@ def get_feature(
reference_structure = args[0]
if not Featurizer._reference_structure_equality(
reference_structure,
ensemble[feature].metadata["reference_structure"],
ensemble[feature].metadata["reference_structure_coords"],
ensemble[feature].metadata["reference_structure_top"],
):
recompute = True

Expand All @@ -1251,7 +1249,10 @@ def get_feature(
reference_structure = kwargs[key]
if not Featurizer._reference_structure_equality(
reference_structure,
ensemble[feature].metadata["reference_structure"],
ensemble[feature].metadata[
"reference_structure_coords"
],
ensemble[feature].metadata["reference_structure_top"],
):
recompute = True
break
Expand Down
9 changes: 7 additions & 2 deletions proteka/quantity/meta_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,12 +157,17 @@ def write_to_hdf5(self, h5_node, name=None):
None
"""

# local hack for storing strings
value_to_save = self._value
if value_to_save.dtype == np.dtype("O"):
value_to_save = np.asarray(str(value_to_save[()]), dtype="O")

def overwrite(dataset_node, input_pattern="h5_node"):
# overwrite an existing Dataset in HDF5 file
warn(
f"Input `{input_pattern}` correponds to existing Dataset with name {dataset_node.name}. Overwritting..."
)
dataset_node[...] = self._value
dataset_node[...] = value_to_save
for k, v in self.metadata.items():
dataset_node.attrs[k] = v

Expand All @@ -182,7 +187,7 @@ def overwrite(dataset_node, input_pattern="h5_node"):
overwrite(h5_node[name], "h5_node[name]")
else:
# create a new Dataset under h5_node
h5_node[name] = self._value
h5_node[name] = value_to_save
for k, v in self.metadata.items():
h5_node[name].attrs[k] = v
else:
Expand Down