Skip to content

Commit 6f9146c

Browse files
committed
[skip ci] Add support for Qbox as backend
1 parent 60ade92 commit 6f9146c

File tree

3 files changed

+255
-3
lines changed

3 files changed

+255
-3
lines changed

pysages/backends/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# SPDX-License-Identifier: MIT
22
# See LICENSE.md and CONTRIBUTORS.md at https://github.com/SSAGESLabs/PySAGES
33

4-
from .contexts import JaxMDContext, JaxMDContextState # noqa: E402, F401
4+
from .contexts import ( # noqa: E402, F401
5+
JaxMDContext,
6+
JaxMDContextState,
7+
QboxContextGenerator,
8+
)
59
from .core import SamplingContext, supported_backends # noqa: E402, F401

pysages/backends/core.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from importlib import import_module
55

6-
from pysages.backends.contexts import JaxMDContext
6+
from pysages.backends.contexts import JaxMDContext, QboxContext
77
from pysages.typing import Callable, Optional
88

99

@@ -38,6 +38,8 @@ def __init__(
3838
self._backend_name = "lammps"
3939
elif module_name.startswith("simtk.openmm") or module_name.startswith("openmm"):
4040
self._backend_name = "openmm"
41+
elif isinstance(context, QboxContext):
42+
self._backend_name = "qbox"
4143

4244
if self._backend_name is None:
4345
backends = ", ".join(supported_backends())
@@ -74,4 +76,4 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
7476

7577

7678
def supported_backends():
77-
return ("ase", "hoomd", "jax-md", "lammps", "openmm")
79+
return ("ase", "hoomd", "jax-md", "lammps", "openmm", "qbox")

pysages/backends/qbox.py

Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
# SPDX-License-Identifier: MIT
2+
# See LICENSE.md and CONTRIBUTORS.md at https://github.com/SSAGESLabs/PySAGES
3+
4+
"""
5+
This module defines the Sampler class, which is a LAMMPS fix that enables any PySAGES
6+
SamplingMethod to be hooked to a LAMMPS simulation instance.
7+
"""
8+
9+
from jax import jit
10+
from jax import numpy as np
11+
from plum import Val, type_parameter
12+
13+
from pysages.backends.core import SamplingContext
14+
from pysages.backends.snapshot import (
15+
Box,
16+
HelperMethods,
17+
Snapshot,
18+
SnapshotMethods,
19+
build_data_querier,
20+
)
21+
from pysages.typing import Callable, Optional
22+
from pysages.utils import dispatch, last, parse_array
23+
24+
kConversionFactor = 1836.152674 # to convert mv²/2 to energy units
25+
26+
27+
class Sampler:
28+
"""
29+
Allows performing enhanced sampling simulations with Qbox as a backend.
30+
31+
Parameters
32+
----------
33+
34+
context: QboxContext
35+
Contains a running instance of a Qbox simulation to which the PySAGES sampling
36+
machinery will be hooked.
37+
38+
sampling_method: SamplingMethod
39+
The sampling method to be used.
40+
41+
callbacks: Optional[Callback]
42+
Some methods define callbacks for logging, but it can also be user-defined.
43+
"""
44+
45+
def __init__(self, context, sampling_method, callback: Optional[Callable]):
46+
self.context = context
47+
self.callback = callback
48+
49+
self.snapshot = self.take_snapshot()
50+
helpers, bias, atom_names, cv_indices = build_helpers(context, sampling_method)
51+
_, initialize, method_update = sampling_method.build(self.snapshot, helpers)
52+
53+
# Initialize external forces for each atom
54+
for i in cv_indices:
55+
name = atom_names[i]
56+
# Initialize with zero force
57+
cmd = f"extforce define atomic {name} {name} 0.0 0.0 0.0"
58+
context.process_input(cmd)
59+
60+
self.state = initialize()
61+
self._update_box = lambda: self.snapshot.box
62+
self._method_update = method_update
63+
self._bias = bias
64+
65+
def _pack_snapshot(self, masses, ids, box, dt):
66+
"""Returns the dynamic properties of the system."""
67+
positions = atom_property(self.context, Val("position"))
68+
velocities = atom_property(self.context, Val("velocity"))
69+
forces = atom_property(self.context, Val("force"))
70+
return Snapshot(positions, (velocities, masses), forces, ids, None, box, dt)
71+
72+
def _update_snapshot(self):
73+
"""Updates the snapshot with the latest properties from Qbox."""
74+
snapshot = self.snapshot
75+
_, masses = snapshot.vel_mass
76+
return self._pack_snapshot(masses, snapshot.ids, self._update_box(), snapshot.dt)
77+
78+
def restore(self, prev_snapshot):
79+
"""Replaces this sampler's snapshot with `prev_snapshot`."""
80+
context = self.context
81+
names = atom_property(context, Val("name"))
82+
positions = prev_snapshot.positions
83+
velocities, _ = prev_snapshot.vel_mass
84+
85+
for name, x, v in zip(names, positions, velocities):
86+
cmd = f"move {name} to {x[0]} {x[1]} {x[2]} {v[0]} {v[1]} {v[2]}"
87+
context.process_input(cmd)
88+
89+
# Recompute ground-state energies and forces.
90+
# NOTE: Check in the future how to use Qbox `load` and `save` commands to also
91+
# include the electronic wave function data.
92+
context.process_input(f"run 0 {context.niter} {context.nitscf}")
93+
self.snapshot = self._update_snapshot()
94+
95+
def take_snapshot(self):
96+
"""Returns a copy of the current snapshot of the system."""
97+
masses = atom_property(self.context, Val("mass"))
98+
ids = np.arange(len(masses))
99+
snapshot_box = Box(*box(self.context))
100+
dt = timestep(self.context)
101+
return self._pack_snapshot(masses, ids, snapshot_box, dt)
102+
103+
def update(self, timestep):
104+
"""Update the sampling method state and apply bias."""
105+
self.snapshot = self._update_snapshot()
106+
self.state = self._method_update(self.snapshot, self.state)
107+
self._bias(self.snapshot, self.state)
108+
if self.callback:
109+
self.callback(self.snapshot, self.state, timestep)
110+
111+
def run(self, nsteps: int):
112+
"""Run the Qbox simulation for nsteps."""
113+
cmd = f"run {self.context.niter} {self.context.nitscf} {self.context.nite}"
114+
for step in range(nsteps):
115+
# Send run command to Qbox for a single step
116+
self.context.process_input(cmd)
117+
# Update sampling method state after each step
118+
self.update(step)
119+
120+
121+
def build_snapshot_methods(sampling_method):
122+
"""
123+
Builds methods for retrieving snapshot properties in a format useful for collective
124+
variable calculations.
125+
"""
126+
127+
def positions(snapshot):
128+
return snapshot.positions
129+
130+
def indices(snapshot):
131+
return snapshot.ids
132+
133+
def momenta(snapshot):
134+
V, M = snapshot.vel_mass
135+
return (M * V).flatten()
136+
137+
def masses(snapshot):
138+
_, M = snapshot.vel_mass
139+
return M
140+
141+
return SnapshotMethods(positions, indices, jit(momenta), masses)
142+
143+
144+
def build_helpers(context, sampling_method):
145+
"""
146+
Builds helper methods used for restoring snapshots and biasing a simulation.
147+
"""
148+
# Precompute atom names since they won't change
149+
atom_names = atom_property(context, Val("name"))
150+
151+
cv_indices = set()
152+
for cv in sampling_method.cvs:
153+
cv_indices.update(n.item() for n in cv.indices)
154+
155+
def to_force_units(x):
156+
return kConversionFactor * x
157+
158+
def extforce_cmd(name, force):
159+
return f"extforce set {name} {force[0]} {force[1]} {force[2]}"
160+
161+
def bias(snapshot, state):
162+
"""Adds the computed bias to the forces using Qbox's extforce command."""
163+
if state.bias is None:
164+
return
165+
# Generate and send all extforce commands
166+
context.process_input(extforce_cmd(atom_names[i], state.bias[i]) for i in cv_indices)
167+
168+
snapshot_methods = build_snapshot_methods(sampling_method)
169+
flags = sampling_method.snapshot_flags
170+
helpers = HelperMethods(build_data_querier(snapshot_methods, flags), lambda: 3, to_force_units)
171+
172+
return helpers, bias, atom_names, cv_indices
173+
174+
175+
@dispatch
176+
def atom_property(context, prop: Val):
177+
return atom_property(context, *specialize(context, prop))
178+
179+
180+
@dispatch
181+
def atom_property(context, xml_tag, extract, gather):
182+
atomset = last(context.state.iter("atomset"))
183+
if atomset is None:
184+
context.process_input("run 0")
185+
atomset = last(context.state.iter("atomset"))
186+
return gather(extract(elem) for elem in atomset.iter(xml_tag))
187+
188+
189+
@dispatch
190+
def specialize(context, prop: Val["name"]): # noqa: F821
191+
return (
192+
"atom", # xml_tag
193+
(lambda s: s.attrib["name"]), # extract
194+
list, # gather
195+
)
196+
197+
198+
@dispatch
199+
def specialize(context, prop: Val["mass"]): # noqa: F821
200+
return (
201+
"atom", # xml_tag
202+
(lambda s: context.species_masses[s.attrib["species"]]), # extract
203+
(lambda d: np.array(list(d)).reshape(-1, 1)), # gather
204+
)
205+
206+
207+
@dispatch
208+
def specialize(context, prop: Val):
209+
return (
210+
type_parameter(prop), # xml_tag
211+
(lambda s: s.text), # extract
212+
(lambda d: parse_array(" ".join(d))), # gather
213+
)
214+
215+
216+
def box(context):
217+
elem = last(context.state.iter("unit_cell"))
218+
if elem is None:
219+
context.process_input("print cell")
220+
elem = context.state.find("unit_cell")
221+
cell_vecs = " ".join(elem.attrib.values())
222+
H = parse_array(cell_vecs, transpose=True)
223+
origin = np.array([0.0, 0.0, 0.0])
224+
return Box(H, origin)
225+
226+
227+
def timestep(context):
228+
context.process_input("print dt")
229+
elem = context.state.find("cmd")
230+
return float(elem.tail.strip("\ndt= "))
231+
232+
233+
def bind(sampling_context: SamplingContext, callback: Optional[Callable], **kwargs):
234+
"""
235+
Sets up and returns a Sampler which enables performing enhanced sampling simulations.
236+
237+
This function takes a `sampling_context` that has its context attribute as an instance
238+
of a `QboxContext,` and creates a `Sampler` object that connects the PySAGES
239+
sampling method to the Qbox simulation. It also modifies the `sampling_context`'s
240+
`view` and `run` attributes to call the Qbox `run` command.
241+
"""
242+
context = sampling_context.context
243+
sampler = Sampler(context, sampling_context.method, callback)
244+
sampling_context.run = sampler.run
245+
246+
return sampler

0 commit comments

Comments
 (0)