Skip to content

Commit 4c5e7bf

Browse files
authored
NPT support (#1972)
add NPT support + log T, P, KE at each frame as well.
1 parent b5cec2e commit 4c5e7bf

4 files changed

Lines changed: 106 additions & 1 deletion

File tree

src/fairchem/core/components/calculate/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from ._single.relaxation_runner import RelaxationRunner
2626
from ._single.singlepoint_runner import SinglePointRunner
2727
from .simulation_tools.thermostats import (
28+
BerendsenNPT,
2829
BussiThermostat,
2930
LangevinThermostat,
3031
NoseHooverNVT,
@@ -37,6 +38,7 @@
3738
"AdsorbMLRunner",
3839
"AdsorptionRunner",
3940
"AdsorptionSinglePointRunner",
41+
"BerendsenNPT",
4042
"BussiThermostat",
4143
"ElasticityRunner",
4244
"KappaRunner",

src/fairchem/core/components/calculate/simulation_tools/thermostats.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from ase.md.bussi import Bussi
1717
from ase.md.langevin import Langevin
1818
from ase.md.nose_hoover_chain import NoseHooverChainNVT
19+
from ase.md.nptberendsen import NPTBerendsen
1920
from ase.md.verlet import VelocityVerlet
2021
from monty.json import jsanitize
2122

@@ -210,3 +211,33 @@ def restore_state(self, dyn: MolecularDynamics, state: dict[str, Any]) -> None:
210211
rng["cached_gaussian"],
211212
)
212213
)
214+
215+
216+
@dataclass
217+
class BerendsenNPT(Thermostat):
218+
"""
219+
Berendsen NPT thermostat/barostat for constant pressure simulations.
220+
"""
221+
222+
temperature_K: float
223+
pressure_bar: float = 1.0
224+
taut_fs: float = 5.0
225+
taup_fs: float = 500.0
226+
compressibility_bar: float = 5e-7
227+
228+
def build(self, atoms: Atoms, timestep_fs: float) -> MolecularDynamics:
229+
return NPTBerendsen(
230+
atoms=atoms,
231+
timestep=timestep_fs * units.fs,
232+
temperature_K=self.temperature_K,
233+
pressure_au=self.pressure_bar * units.bar,
234+
taut=self.taut_fs * units.fs,
235+
taup=self.taup_fs * units.fs,
236+
compressibility_au=self.compressibility_bar / units.bar,
237+
)
238+
239+
def save_state(self, dyn: MolecularDynamics) -> dict[str, Any]:
240+
return {"class_name": "BerendsenNPT"}
241+
242+
def restore_state(self, dyn: MolecularDynamics, state: dict[str, Any]) -> None:
243+
pass

src/fairchem/core/components/calculate/simulation_tools/trajectory.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from pathlib import Path
1212
from typing import TYPE_CHECKING
1313

14+
import ase.units
1415
import numpy as np
1516
import pyarrow as pa
1617
import pyarrow.parquet as pq
@@ -42,6 +43,9 @@ class TrajectoryFrame:
4243
energy: float | None = None
4344
forces: np.ndarray | None = None # (N, 3)
4445
stress: np.ndarray | None = None # (6,) Voigt notation
46+
temperature: float | None = None # Kelvin
47+
kinetic_energy: float | None = None # eV
48+
pressure: float | None = None # bar
4549
sid: str | int | None = None
4650

4751
def to_dict(self) -> dict:
@@ -67,6 +71,12 @@ def to_dict(self) -> dict:
6771
d["forces"] = self.forces.tolist()
6872
if self.stress is not None:
6973
d["stress"] = self.stress.tolist()
74+
if self.temperature is not None:
75+
d["temperature"] = self.temperature
76+
if self.kinetic_energy is not None:
77+
d["kinetic_energy"] = self.kinetic_energy
78+
if self.pressure is not None:
79+
d["pressure"] = self.pressure
7080
return d
7181

7282
@classmethod
@@ -104,6 +114,22 @@ def from_atoms(
104114
except (PropertyNotImplementedError, RuntimeError):
105115
velocities = None
106116

117+
try:
118+
temperature = atoms.get_temperature()
119+
except Exception:
120+
temperature = None
121+
122+
try:
123+
kinetic_energy = atoms.get_kinetic_energy()
124+
except Exception:
125+
kinetic_energy = None
126+
127+
# Pressure from stress: P = -trace(stress)/3, converted to bar
128+
if stress is not None:
129+
pressure = -stress[:3].mean() / ase.units.bar
130+
else:
131+
pressure = None
132+
107133
return cls(
108134
step=step,
109135
time=time,
@@ -115,6 +141,9 @@ def from_atoms(
115141
energy=energy,
116142
forces=forces,
117143
stress=stress,
144+
temperature=temperature,
145+
kinetic_energy=kinetic_energy,
146+
pressure=pressure,
118147
sid=atoms.info.get("sid"),
119148
)
120149

tests/core/components/test_md_runner.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from ase.md.verlet import VelocityVerlet
2727

2828
from fairchem.core.components.calculate import (
29+
BerendsenNPT,
2930
BussiThermostat,
3031
LangevinThermostat,
3132
MDRunner,
@@ -148,8 +149,15 @@ def test_md_correctness_vs_ase(self, cu_atoms, results_dir):
148149
NoseHooverNVT(temperature_K=300.0, tdamp_fs=25.0),
149150
BussiThermostat(temperature_K=300.0, taut_fs=25.0),
150151
LangevinThermostat(temperature_K=300.0, friction_per_fs=0.01),
152+
BerendsenNPT(
153+
temperature_K=300.0,
154+
pressure_bar=1.0,
155+
taut_fs=500.0,
156+
taup_fs=1000.0,
157+
compressibility_bar=1.0 / 140e9,
158+
),
151159
],
152-
ids=["VelocityVerlet", "NoseHoover", "Bussi", "Langevin"],
160+
ids=["VelocityVerlet", "NoseHoover", "Bussi", "Langevin", "BerendsenNPT"],
153161
)
154162
def test_checkpoint_resume(self, cu_atoms, results_dir, thermostat):
155163
"""
@@ -333,3 +341,38 @@ def test_stopfair_graceful_stop(self, cu_atoms, results_dir):
333341

334342
traj_df = pd.read_parquet(results["trajectory_file"])
335343
assert list(traj_df["step"]) == [0, 10, 20]
344+
345+
def test_npt_cell_changes(self, cu_atoms, results_dir):
346+
"""
347+
Verify that NPT simulation changes the cell volume.
348+
"""
349+
md_results_dir = results_dir / "results"
350+
md_results_dir.mkdir()
351+
352+
atoms = cu_atoms.copy()
353+
initial_volume = atoms.get_volume()
354+
355+
# Use a large pressure to drive a noticeable volume change
356+
thermostat = BerendsenNPT(
357+
temperature_K=300.0,
358+
pressure_bar=1e5,
359+
taut_fs=100.0,
360+
taup_fs=100.0,
361+
compressibility_bar=1.0 / 140e9,
362+
)
363+
364+
runner = MDRunner(
365+
calculator=EMT(),
366+
atoms=atoms,
367+
thermostat=thermostat,
368+
timestep_fs=1.0,
369+
steps=200,
370+
trajectory_interval=50,
371+
log_interval=50,
372+
trajectory_writer=partial(ParquetTrajectoryWriter, flush_interval=1000),
373+
)
374+
runner._job_config = _create_mock_job_config(str(md_results_dir))
375+
runner.calculate(job_num=0, num_jobs=1)
376+
377+
final_volume = atoms.get_volume()
378+
assert initial_volume != pytest.approx(final_volume, rel=1e-6)

0 commit comments

Comments
 (0)