diff --git a/proteka/dataset/ensemble_test.py b/proteka/dataset/ensemble_test.py
index b75fdf1..148e062 100644
--- a/proteka/dataset/ensemble_test.py
+++ b/proteka/dataset/ensemble_test.py
@@ -4,8 +4,11 @@
 from proteka import Ensemble, UnitSystem, Quantity
 from proteka.quantity.quantity_shapes import PerFrameQuantity
 from proteka.dataset.ensemble import HDF5Group
+from proteka.metrics import Featurizer
+from proteka.metrics.utils import get_CLN_trajectory
 import pytest
 import numpy as np
+import mdtraj as md
 from .top_utils import json2top, top2json
 
 
@@ -21,6 +24,27 @@ def example_ensemble(example_json_topology):
     return ensemble
 
 
+@pytest.fixture
+def cln_example_ensemble():
+    """Create a noisy CLN ensemble from 10 concatenated trajectories."""
+    cln_trajs = [get_CLN_trajectory(seed=i) for i in range(10)]
+    trajectory_slices = {}  # NOTE(review): built but never passed to Ensemble
+    current_frames = 0
+    coords = []
+    for i, traj in enumerate(cln_trajs):
+        frames = traj.xyz.shape[0]
+        coords.append(traj.xyz)
+        slc = slice(current_frames, current_frames + frames, 1)
+        trajectory_slices[f"traj_{i}"] = slc
+        current_frames += frames
+    ensemble = Ensemble(
+        name="cln_example_ensemble",
+        top=cln_trajs[0].top,
+        coords=np.concatenate(coords),
+    )
+    return ensemble
+
+
 def test_casting(example_ensemble):
     assert example_ensemble["coords"].unit == "nanometers"
     example_ensemble.forces = np.zeros((10, example_ensemble.n_atoms, 3))
@@ -131,6 +155,221 @@ def test_ensemble_to_h5(example_ensemble, tmpdir):
     assert ensemble2["custom_field"].unit == "dimensionless"
 
 
+def test_featurized_all_ensemble_to_h5(cln_example_ensemble, tmpdir):
+    """Test HDF5 round trip with weights and all featurizer quantities."""
+    ensemble = cln_example_ensemble
+    ensemble.forces = np.zeros((ensemble.n_frames, ensemble.n_atoms, 3))
+    ensemble.custom_field = np.zeros((5, ensemble.n_atoms, 3))
+    # random weights inspired by MSM analysis.
+    weights = np.random.rand(ensemble.n_frames)
+    ensemble.set_quantity("weights", weights)
+    # quantities from a featurizer.
+    root_dir = Path(__file__).parent.parent.parent
+    cln_path = (
+        root_dir / "examples" / "example_dataset_files" / "cln_folded.pdb"
+    )
+    cln_native_structure = md.load_pdb(cln_path)
+    feat = Featurizer()
+    feat.add_rg(ensemble)
+    feat.add_fraction_native_contacts(
+        ensemble, reference_structure=cln_native_structure
+    )
+    feat.add_rmsd(ensemble, reference_structure=cln_native_structure)
+    feat.add_ca_distances(ensemble, offset=1, subset_selection="name CA")
+    with h5py.File(tmpdir / "test.h5", "w") as f:
+        ensemble.write_to_hdf5(f, name="example_ensemble")
+
+    with h5py.File(tmpdir / "test.h5", "r") as f:
+        group = f["example_ensemble"]
+        ensemble2 = Ensemble.from_hdf5(group)
+
+    assert ensemble2.n_frames == ensemble.n_frames
+    assert np.allclose(ensemble2.coords, ensemble.coords)
+    assert np.allclose(ensemble2.forces, ensemble.forces)
+    assert np.allclose(ensemble2.custom_field, ensemble.custom_field)
+    assert set(ensemble2.list_quantities()) == set(ensemble.list_quantities())
+    assert np.allclose(
+        ensemble2.get_quantity("weights").raw_value,
+        ensemble.get_quantity("weights").raw_value,
+    )
+    assert np.allclose(
+        ensemble2.get_quantity("rg").raw_value,
+        ensemble.get_quantity("rg").raw_value,
+    )
+    assert np.allclose(
+        ensemble2.get_quantity("rmsd").raw_value,
+        ensemble.get_quantity("rmsd").raw_value,
+    )
+    assert np.allclose(
+        ensemble2.get_quantity("fraction_native_contacts").raw_value,
+        ensemble.get_quantity("fraction_native_contacts").raw_value,
+    )
+    assert np.allclose(
+        ensemble2.get_quantity("ca_distances").raw_value,
+        ensemble.get_quantity("ca_distances").raw_value,
+    )
+    assert ensemble2.top == ensemble.top
+    assert ensemble2.name == ensemble.name
+    assert ensemble2["custom_field"].unit == "dimensionless"
+
+
+def test_featurized_weights_ensemble_to_h5(cln_example_ensemble, tmpdir):
+    """Test HDF5 round trip of an ensemble carrying per-frame weights."""
+    ensemble = cln_example_ensemble
+    ensemble.forces = np.zeros((ensemble.n_frames, ensemble.n_atoms, 3))
+    ensemble.custom_field = np.zeros((5, ensemble.n_atoms, 3))
+    # random weights inspired by MSM analysis.
+    weights = np.random.rand(ensemble.n_frames)
+    ensemble.set_quantity("weights", weights)
+
+    with h5py.File(tmpdir / "test.h5", "w") as f:
+        ensemble.write_to_hdf5(f, name="example_ensemble")
+
+    with h5py.File(tmpdir / "test.h5", "r") as f:
+        group = f["example_ensemble"]
+        ensemble2 = Ensemble.from_hdf5(group)
+
+    assert ensemble2.n_frames == ensemble.n_frames
+    assert np.allclose(ensemble2.coords, ensemble.coords)
+    assert np.allclose(ensemble2.forces, ensemble.forces)
+    assert np.allclose(ensemble2.custom_field, ensemble.custom_field)
+    assert set(ensemble2.list_quantities()) == set(ensemble.list_quantities())
+    assert np.allclose(
+        ensemble2.get_quantity("weights").raw_value,
+        ensemble.get_quantity("weights").raw_value,
+    )
+    assert ensemble2.top == ensemble.top
+    assert ensemble2.name == ensemble.name
+    assert ensemble2["custom_field"].unit == "dimensionless"
+
+
+def test_featurized_rg_ensemble_to_h5(cln_example_ensemble, tmpdir):
+    """Test HDF5 round trip of an ensemble with the `rg` feature."""
+    ensemble = cln_example_ensemble
+    ensemble.forces = np.zeros((ensemble.n_frames, ensemble.n_atoms, 3))
+    ensemble.custom_field = np.zeros((5, ensemble.n_atoms, 3))
+    feat = Featurizer()
+    feat.add_rg(ensemble)
+    with h5py.File(tmpdir / "test.h5", "w") as f:
+        ensemble.write_to_hdf5(f, name="example_ensemble")
+
+    with h5py.File(tmpdir / "test.h5", "r") as f:
+        group = f["example_ensemble"]
+        ensemble2 = Ensemble.from_hdf5(group)
+
+    assert ensemble2.n_frames == ensemble.n_frames
+    assert np.allclose(ensemble2.coords, ensemble.coords)
+    assert np.allclose(ensemble2.forces, ensemble.forces)
+    assert np.allclose(ensemble2.custom_field, ensemble.custom_field)
+    assert set(ensemble2.list_quantities()) == set(ensemble.list_quantities())
+    assert np.allclose(
+        ensemble2.get_quantity("rg").raw_value,
+        ensemble.get_quantity("rg").raw_value,
+    )
+    assert ensemble2.top == ensemble.top
+    assert ensemble2.name == ensemble.name
+    assert ensemble2["custom_field"].unit == "dimensionless"
+
+
+def test_featurized_fraction_native_contacs_ensemble_to_h5(
+    cln_example_ensemble, tmpdir
+):
+    """Test HDF5 round trip with the fraction-of-native-contacts feature."""
+    ensemble = cln_example_ensemble
+    ensemble.forces = np.zeros((ensemble.n_frames, ensemble.n_atoms, 3))
+    ensemble.custom_field = np.zeros((5, ensemble.n_atoms, 3))
+    root_dir = Path(__file__).parent.parent.parent
+    cln_path = (
+        root_dir / "examples" / "example_dataset_files" / "cln_folded.pdb"
+    )
+    cln_native_structure = md.load_pdb(cln_path)
+    feat = Featurizer()
+    feat.add_fraction_native_contacts(
+        ensemble, reference_structure=cln_native_structure
+    )
+    with h5py.File(tmpdir / "test.h5", "w") as f:
+        ensemble.write_to_hdf5(f, name="example_ensemble")
+
+    with h5py.File(tmpdir / "test.h5", "r") as f:
+        group = f["example_ensemble"]
+        ensemble2 = Ensemble.from_hdf5(group)
+
+    assert ensemble2.n_frames == ensemble.n_frames
+    assert np.allclose(ensemble2.coords, ensemble.coords)
+    assert np.allclose(ensemble2.forces, ensemble.forces)
+    assert np.allclose(ensemble2.custom_field, ensemble.custom_field)
+    assert set(ensemble2.list_quantities()) == set(ensemble.list_quantities())
+    assert np.allclose(
+        ensemble2.get_quantity("fraction_native_contacts").raw_value,
+        ensemble.get_quantity("fraction_native_contacts").raw_value,
+    )
+    assert ensemble2.top == ensemble.top
+    assert ensemble2.name == ensemble.name
+    assert ensemble2["custom_field"].unit == "dimensionless"
+
+
+def test_featurized_rmsd_ensemble_to_h5(cln_example_ensemble, tmpdir):
+    """Test HDF5 round trip of an ensemble with the `rmsd` feature."""
+    ensemble = cln_example_ensemble
+    ensemble.forces = np.zeros((ensemble.n_frames, ensemble.n_atoms, 3))
+    ensemble.custom_field = np.zeros((5, ensemble.n_atoms, 3))
+    root_dir = Path(__file__).parent.parent.parent
+    cln_path = (
+        root_dir / "examples" / "example_dataset_files" / "cln_folded.pdb"
+    )
+    cln_native_structure = md.load_pdb(cln_path)
+    feat = Featurizer()
+    feat.add_rmsd(ensemble, reference_structure=cln_native_structure)
+    with h5py.File(tmpdir / "test.h5", "w") as f:
+        ensemble.write_to_hdf5(f, name="example_ensemble")
+
+    with h5py.File(tmpdir / "test.h5", "r") as f:
+        group = f["example_ensemble"]
+        ensemble2 = Ensemble.from_hdf5(group)
+
+    assert ensemble2.n_frames == ensemble.n_frames
+    assert np.allclose(ensemble2.coords, ensemble.coords)
+    assert np.allclose(ensemble2.forces, ensemble.forces)
+    assert np.allclose(ensemble2.custom_field, ensemble.custom_field)
+    assert set(ensemble2.list_quantities()) == set(ensemble.list_quantities())
+    assert np.allclose(
+        ensemble2.get_quantity("rmsd").raw_value,
+        ensemble.get_quantity("rmsd").raw_value,
+    )
+    assert ensemble2.top == ensemble.top
+    assert ensemble2.name == ensemble.name
+    assert ensemble2["custom_field"].unit == "dimensionless"
+
+
+def test_featurized_ca_dist_ensemble_to_h5(cln_example_ensemble, tmpdir):
+    """Test HDF5 round trip of an ensemble with the `ca_distances` feature."""
+    ensemble = cln_example_ensemble
+    ensemble.forces = np.zeros((ensemble.n_frames, ensemble.n_atoms, 3))
+    ensemble.custom_field = np.zeros((5, ensemble.n_atoms, 3))
+    # quantities from a featurizer.
+    feat = Featurizer()
+    feat.add_ca_distances(ensemble, offset=1, subset_selection="name CA")
+    with h5py.File(tmpdir / "test.h5", "w") as f:
+        ensemble.write_to_hdf5(f, name="example_ensemble")
+
+    with h5py.File(tmpdir / "test.h5", "r") as f:
+        group = f["example_ensemble"]
+        ensemble2 = Ensemble.from_hdf5(group)
+
+    assert ensemble2.n_frames == ensemble.n_frames
+    assert np.allclose(ensemble2.coords, ensemble.coords)
+    assert np.allclose(ensemble2.forces, ensemble.forces)
+    assert np.allclose(ensemble2.custom_field, ensemble.custom_field)
+    assert set(ensemble2.list_quantities()) == set(ensemble.list_quantities())
+    assert np.allclose(
+        ensemble2.get_quantity("ca_distances").raw_value,
+        ensemble.get_quantity("ca_distances").raw_value,
+    )
+    assert ensemble2.top == ensemble.top
+    assert ensemble2.name == ensemble.name
+    assert ensemble2["custom_field"].unit == "dimensionless"
+
+
 def test_ensemble_with_extra_builtin_quantity(example_ensemble):
     unit_system = UnitSystem(
         "Angstrom",
diff --git a/proteka/metrics/featurizer.py b/proteka/metrics/featurizer.py
index 09171b9..ebc8dc2 100644
--- a/proteka/metrics/featurizer.py
+++ b/proteka/metrics/featurizer.py
@@ -6,7 +6,6 @@
 import warnings
 import numpy as np
 import mdtraj as md
-from typing import Dict, Optional, List, Tuple
 from ..dataset import Ensemble
 from ..quantity import Quantity
 from ..dataset.top_utils import top2json
@@ -484,10 +483,8 @@ def add_rmsd(
 
         metadata = {
             "feature": "rmsd",
-            "reference_structure": {
-                "coords": ref_coords,
-                "top": ref_top,
-            },
+            "reference_structure_coords": ref_coords,
+            "reference_structure_top": ref_top,
             "atom_selection": atom_selection,
         }
 
@@ -881,11 +878,9 @@ def add_fraction_native_contacts(
 
         metadata = {
             "feature": "fraction_native_contacts",
-            "reference_structure": {
-                "coords": native_coords,
-                "top": native_top,
-                "atomistic": use_atomistic_reference,
-            },
+            "reference_structure_coords": native_coords,
+            "reference_structure_top": top2json(native_top),
+            "reference_structure_is_atomistic": use_atomistic_reference,
             "atom_selection": atom_selection,
             "beta": beta,
             "lam": lam,
@@ -1132,7 +1127,8 @@ def add_temporary_feature(
     @staticmethod
     def _reference_structure_equality(
         input_structure: md.Trajectory,
-        serialized_structure: Dict,
+        stored_coords: Union[List, np.ndarray],
+        stored_top: str,
     ) -> bool:
         """Helper method for testing reference structure serialized equality
         for RMSD recomputation
@@ -1140,8 +1136,10 @@ def _reference_structure_equality(
         ----------
         input_structures:
             input MDTraj single frame Trajectory for proposed RMSD calculations
-        serialized_structure:
-            Serialized reference structure
+        stored_coords:
+            Saved reference structure coordinates
+        stored_top:
+            Saved reference structure topology
 
         Returns
         -------
@@ -1150,14 +1148,17 @@ def _reference_structure_equality(
             the same, True is returned. Else, False is returned.
         """
 
-        ref_coords = input_structure.xyz.tolist()
+        ref_coords = np.array(input_structure.xyz)
+        stored_coords = np.array(stored_coords)
         ref_top = top2json(input_structure.topology)
-        stored_coords = serialized_structure["coords"]
-        stored_top = serialized_structure["top"]
 
         equals = []
         equals.append(stored_top == ref_top)
-        equals.append(stored_coords == ref_coords)
+        try:
+            equals.append(np.allclose(stored_coords, ref_coords))
+        except ValueError:
+            # Shapes cannot broadcast, so the structures differ.
+            equals.append(False)
         return all(equals)
 
     @staticmethod
@@ -1238,7 +1239,8 @@ def get_feature(
                 reference_structure = args[0]
                 if not Featurizer._reference_structure_equality(
                     reference_structure,
-                    ensemble[feature].metadata["reference_structure"],
+                    ensemble[feature].metadata["reference_structure_coords"],
+                    ensemble[feature].metadata["reference_structure_top"],
                 ):
                     recompute = True
 
@@ -1251,7 +1253,10 @@ def get_feature(
                     reference_structure = kwargs[key]
                     if not Featurizer._reference_structure_equality(
                         reference_structure,
-                        ensemble[feature].metadata["reference_structure"],
+                        ensemble[feature].metadata[
+                            "reference_structure_coords"
+                        ],
+                        ensemble[feature].metadata["reference_structure_top"],
                     ):
                         recompute = True
                         break
diff --git a/proteka/quantity/meta_array.py b/proteka/quantity/meta_array.py
index 83528c2..8357a90 100644
--- a/proteka/quantity/meta_array.py
+++ b/proteka/quantity/meta_array.py
@@ -157,12 +157,17 @@ def write_to_hdf5(self, h5_node, name=None):
         None
         """
 
+        # local hack for storing strings
+        value_to_save = self._value
+        if value_to_save.dtype == np.dtype("O"):
+            value_to_save = np.asarray(str(value_to_save[()]), dtype="O")
+
         def overwrite(dataset_node, input_pattern="h5_node"):
             # overwrite an existing Dataset in HDF5 file
             warn(
                 f"Input `{input_pattern}` correponds to existing Dataset with name {dataset_node.name}. Overwritting..."
             )
-            dataset_node[...] = self._value
+            dataset_node[...] = value_to_save
             for k, v in self.metadata.items():
                 dataset_node.attrs[k] = v
 
@@ -182,7 +187,7 @@ def overwrite(dataset_node, input_pattern="h5_node"):
                 overwrite(h5_node[name], "h5_node[name]")
             else:
                 # create a new Dataset under h5_node
-                h5_node[name] = self._value
+                h5_node[name] = value_to_save
                 for k, v in self.metadata.items():
                     h5_node[name].attrs[k] = v
         else: