diff --git a/apax/nodes/model.py b/apax/nodes/model.py index 17bbe69a..1a68eb61 100644 --- a/apax/nodes/model.py +++ b/apax/nodes/model.py @@ -2,6 +2,8 @@ import pathlib import typing as t +import ase.calculators +import ase.calculators.singlepoint import ase.io import numpy as np import pandas as pd @@ -95,8 +97,57 @@ def run(self): """Primary method to run which executes all steps of the model training""" if not self.state.restarted: - ase.io.write(self.train_data_file.as_posix(), self.data) - ase.io.write(self.validation_data_file.as_posix(), self.validation_data) + common_keys = set(self.data[0].calc.results.keys()) + for atoms in self.data[1:]: + common_keys &= set(atoms.calc.results.keys()) + for atoms in self.validation_data: + common_keys &= set(atoms.calc.results.keys()) + log.warning(f"common keys = {common_keys}") + + new_frames = [] + for atoms in self.data: + results = {} + for key in common_keys: + results[key] = atoms.calc.results[key] + + symbols = atoms.get_chemical_symbols() + pbc = atoms.get_pbc() + positions = atoms.get_positions() + cell = atoms.get_cell() + + new_atoms = ase.Atoms( + symbols=symbols, positions=positions, cell=cell, pbc=pbc + ) + + calc = ase.calculators.singlepoint.SinglePointCalculator( + new_atoms, **results + ) + new_atoms.calc = calc + new_frames.append(new_atoms) + + new_val_frames = [] + for atoms in self.validation_data: + results = {} + for key in common_keys: + results[key] = atoms.calc.results[key] + + symbols = atoms.get_chemical_symbols() + pbc = atoms.get_pbc() + positions = atoms.get_positions() + cell = atoms.get_cell() + + new_val_atoms = ase.Atoms( + symbols=symbols, positions=positions, cell=cell, pbc=pbc + ) + + calc = ase.calculators.singlepoint.SinglePointCalculator( + new_val_atoms, **results + ) + new_val_atoms.calc = calc + new_val_frames.append(new_val_atoms) + + ase.io.write(self.train_data_file.as_posix(), new_frames) + ase.io.write(self.validation_data_file.as_posix(), new_val_frames) csv_path = self.model_directory / "log.csv" if self.state.restarted and csv_path.is_file(): diff --git a/apax/utils/convert.py b/apax/utils/convert.py index b5c89802..f947997f 100644 --- a/apax/utils/convert.py +++ b/apax/utils/convert.py @@ -1,3 +1,5 @@ +import logging + import jax.numpy as jnp import numpy as np from ase import Atoms @@ -5,6 +7,8 @@ from apax.utils.jax_md_reduced import space +log = logging.getLogger(__name__) + DTYPE = np.float64 unit_dict = { "Ang": Ang, @@ -147,18 +151,34 @@ def atoms_to_labels( """ labels = { - "forces": [], + # "forces": [], "energy": [], - "stress": [], + # "stress": [], } + + common_keys = set(atoms_list[0].calc.results.keys()) + for atoms in atoms_list[1:]: + common_keys &= set(atoms.calc.results.keys()) + log.info(f"Labels found in the dataset: {common_keys}") + property_names = [p[0] for p in additional_properties] for key in property_names: if key not in labels.keys(): placeholder = {key: []} labels.update(placeholder) + for key in labels.keys(): + if key not in common_keys: + log.error(f"Label {key} missing at least in one structure") + + for key in common_keys: + if key not in labels.keys(): + placeholder = {key: []} + labels.update(placeholder) + for atoms in atoms_list: - for key, val in atoms.calc.results.items(): + for key in common_keys: + val = atoms.calc.results[key] if key == "forces": labels[key].append(val * unit_dict[energy_unit] / unit_dict[pos_unit]) elif key == "energy": @@ -168,7 +188,7 @@ def atoms_to_labels( stress = atoms.get_stress(voigt=False) * factor labels[key].append(stress * atoms.cell.volume) elif key in property_names: - labels[key].append(atoms.calc.results[key]) + labels[key].append(val) labels = prune_dict(labels) return labels