Skip to content

Commit a1a4dab

Browse files
committed
[skip ci] Add support for Qbox as backend
1 parent c643a78 commit a1a4dab

File tree

3 files changed

+250
-3
lines changed

3 files changed

+250
-3
lines changed

pysages/backends/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# SPDX-License-Identifier: MIT
22
# See LICENSE.md and CONTRIBUTORS.md at https://github.com/SSAGESLabs/PySAGES
33

4-
from .contexts import JaxMDContext, JaxMDContextState # noqa: E402, F401
4+
from .contexts import ( # noqa: E402, F401
5+
JaxMDContext,
6+
JaxMDContextState,
7+
QboxContextGenerator,
8+
)
59
from .core import SamplingContext, supported_backends # noqa: E402, F401

pysages/backends/core.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from importlib import import_module
55

6-
from pysages.backends.contexts import JaxMDContext
6+
from pysages.backends.contexts import JaxMDContext, QboxContext
77
from pysages.typing import Callable, Optional
88

99

@@ -38,6 +38,8 @@ def __init__(
3838
self._backend_name = "lammps"
3939
elif module_name.startswith("simtk.openmm") or module_name.startswith("openmm"):
4040
self._backend_name = "openmm"
41+
elif isinstance(context, QboxContext):
42+
self._backend_name = "qbox"
4143

4244
if self._backend_name is None:
4345
backends = ", ".join(supported_backends())
@@ -74,4 +76,4 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
7476

7577

7678
def supported_backends():
77-
return ("ase", "hoomd", "jax-md", "lammps", "openmm")
79+
return ("ase", "hoomd", "jax-md", "lammps", "openmm", "qbox")

pysages/backends/qbox.py

Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
# SPDX-License-Identifier: MIT
2+
# See LICENSE.md and CONTRIBUTORS.md at https://github.com/SSAGESLabs/PySAGES
3+
4+
"""
5+
This module defines the Sampler class, which is a LAMMPS fix that enables any PySAGES
6+
SamplingMethod to be hooked to a LAMMPS simulation instance.
7+
"""
8+
9+
from jax import jit
10+
from jax import numpy as np
11+
from plum import Val, type_parameter
12+
13+
from pysages.backends.core import SamplingContext
14+
from pysages.backends.snapshot import (
15+
Box,
16+
HelperMethods,
17+
Snapshot,
18+
SnapshotMethods,
19+
build_data_querier,
20+
)
21+
from pysages.typing import Callable, Optional
22+
from pysages.utils import dispatch, last, parse_array
23+
24+
kConversionFactor = 1836.152674 # to convert mv²/2 to energy units
25+
26+
27+
class Sampler:
28+
"""
29+
Allows performing enhanced sampling simulations with Qbox as a backend.
30+
31+
Parameters
32+
----------
33+
34+
context: QboxContext
35+
Contains a running instance of a Qbox simulation to which the PySAGES sampling
36+
machinery will be hooked.
37+
38+
sampling_method: SamplingMethod
39+
The sampling method to be used.
40+
41+
callbacks: Optional[Callback]
42+
Some methods define callbacks for logging, but it can also be user-defined.
43+
"""
44+
45+
def __init__(self, context, sampling_method, callback: Optional[Callable]):
46+
self.context = context
47+
self.callback = callback
48+
49+
self.snapshot = self.take_snapshot()
50+
helpers, bias = build_helpers(context, sampling_method)
51+
_, initialize, method_update = sampling_method.build(self.snapshot, helpers)
52+
53+
# Initialize external forces for each atom
54+
for name in atom_property(context, Val("name")):
55+
# Initialize with zero force
56+
cmd = f"extforce define atomic {name} {name} 0.0 0.0 0.0"
57+
context.process_input(cmd)
58+
59+
self.state = initialize()
60+
self._update_box = lambda: self.snapshot.box
61+
self._method_update = method_update
62+
self._bias = bias
63+
64+
def _pack_snapshot(self, masses, ids, box, dt):
65+
"""Returns the dynamic properties of the system."""
66+
positions = atom_property(self.context, Val("position"))
67+
velocities = atom_property(self.context, Val("velocity"))
68+
forces = atom_property(self.context, Val("force"))
69+
return Snapshot(positions, (velocities, masses), forces, ids, None, box, dt)
70+
71+
def _update_snapshot(self):
72+
"""Updates the snapshot with the latest properties from Qbox."""
73+
snapshot = self.snapshot
74+
_, masses = snapshot.vel_mass
75+
return self._pack_snapshot(masses, snapshot.ids, self._update_box(), snapshot.dt)
76+
77+
def restore(self, prev_snapshot):
78+
"""Replaces this sampler's snapshot with `prev_snapshot`."""
79+
context = self.context
80+
names = atom_property(context, Val("name"))
81+
positions = prev_snapshot.positions
82+
velocities, _ = prev_snapshot.vel_mass
83+
84+
for name, x, v in zip(names, positions, velocities):
85+
cmd = f"move {name} to {x[0]} {x[1]} {x[2]} {v[0]} {v[1]} {v[2]}"
86+
context.process_input(cmd)
87+
88+
# Recompute ground-state energies and forces.
89+
# NOTE: Check in the future how to use Qbox `load` and `save` commands to also
90+
# include the electronic wave function data.
91+
context.process_input(f"run 0 {context.niter} {context.nitscf}")
92+
self.snapshot = self._update_snapshot()
93+
94+
def take_snapshot(self):
95+
"""Returns a copy of the current snapshot of the system."""
96+
masses = atom_property(self.context, Val("mass"))
97+
ids = np.arange(len(masses))
98+
snapshot_box = Box(*box(self.context))
99+
dt = timestep(self.context)
100+
return self._pack_snapshot(masses, ids, snapshot_box, dt)
101+
102+
def update(self, timestep):
103+
"""Update the sampling method state and apply bias."""
104+
self.snapshot = self._update_snapshot()
105+
self.state = self._method_update(self.snapshot, self.state)
106+
self._bias(self.snapshot, self.state)
107+
if self.callback:
108+
self.callback(self.snapshot, self.state, timestep)
109+
110+
def run(self, nsteps: int):
111+
"""Run the Qbox simulation for nsteps."""
112+
cmd = f"run 1 {self.context.niter} {self.context.nitscf}"
113+
for step in range(nsteps):
114+
# Send run command to Qbox for a single step
115+
self.context.process_input(cmd)
116+
# Update sampling method state after each step
117+
self.update(step)
118+
119+
120+
def build_snapshot_methods(sampling_method):
121+
"""
122+
Builds methods for retrieving snapshot properties in a format useful for collective
123+
variable calculations.
124+
"""
125+
126+
def positions(snapshot):
127+
return snapshot.positions
128+
129+
def indices(snapshot):
130+
return snapshot.ids
131+
132+
def momenta(snapshot):
133+
V, M = snapshot.vel_mass
134+
return (M * V).flatten()
135+
136+
def masses(snapshot):
137+
_, M = snapshot.vel_mass
138+
return M
139+
140+
return SnapshotMethods(positions, indices, jit(momenta), masses)
141+
142+
143+
def build_helpers(context, sampling_method):
144+
"""
145+
Builds helper methods used for restoring snapshots and biasing a simulation.
146+
"""
147+
# Precompute atom names since they won't change
148+
names = atom_property(context, Val("name"))
149+
150+
def to_force_units(x):
151+
return kConversionFactor * x
152+
153+
def extforce_cmd(name, force):
154+
return f"extforce set {name} {force[0]} {force[1]} {force[2]}"
155+
156+
def bias(snapshot, state):
157+
"""Adds the computed bias to the forces using Qbox's extforce command."""
158+
if state.bias is None:
159+
return
160+
# Generate and send all extforce commands
161+
context.process_input(extforce_cmd(name, force) for name, force in zip(names, state.bias))
162+
163+
snapshot_methods = build_snapshot_methods(sampling_method)
164+
flags = sampling_method.snapshot_flags
165+
helpers = HelperMethods(build_data_querier(snapshot_methods, flags), lambda: 3, to_force_units)
166+
167+
return helpers, bias
168+
169+
170+
@dispatch
171+
def atom_property(context, prop: Val):
172+
return atom_property(context, *specialize(context, prop))
173+
174+
175+
@dispatch
176+
def atom_property(context, xml_tag, extract, gather):
177+
atomset = last(context.state.iter("atomset"))
178+
if atomset is None:
179+
context.process_input("run 0")
180+
atomset = last(context.state.iter("atomset"))
181+
return gather(extract(elem) for elem in atomset.iter(xml_tag))
182+
183+
184+
@dispatch
185+
def specialize(context, prop: Val["name"]): # noqa: F821
186+
return (
187+
"atom", # xml_tag
188+
(lambda s: s.attrib["name"]), # extract
189+
list, # gather
190+
)
191+
192+
193+
@dispatch
194+
def specialize(context, prop: Val["mass"]): # noqa: F821
195+
return (
196+
"atom", # xml_tag
197+
(lambda s: context.species_masses[s.attrib["species"]]), # extract
198+
(lambda d: np.array(list(d)).reshape(-1, 1)), # gather
199+
)
200+
201+
202+
@dispatch
203+
def specialize(context, prop: Val):
204+
return (
205+
type_parameter(prop), # xml_tag
206+
(lambda s: s.text), # extract
207+
(lambda d: parse_array(" ".join(d))), # gather
208+
)
209+
210+
211+
def box(context):
212+
elem = last(context.state.iter("unit_cell"))
213+
if elem is None:
214+
context.process_input("print cell")
215+
elem = context.state.find("unit_cell")
216+
cell_vecs = " ".join(elem.attrib.values())
217+
H = parse_array(cell_vecs, transpose=True)
218+
origin = np.array([0.0, 0.0, 0.0])
219+
return Box(H, origin)
220+
221+
222+
def timestep(context):
223+
context.process_input("print dt")
224+
elem = context.state.find("cmd")
225+
return float(elem.tail.strip("\ndt= "))
226+
227+
228+
def bind(sampling_context: SamplingContext, callback: Optional[Callable], **kwargs):
229+
"""
230+
Sets up and returns a Sampler which enables performing enhanced sampling simulations.
231+
232+
This function takes a `sampling_context` that has its context attribute as an instance
233+
of a `QboxContext,` and creates a `Sampler` object that connects the PySAGES
234+
sampling method to the Qbox simulation. It also modifies the `sampling_context`'s
235+
`view` and `run` attributes to call the Qbox `run` command.
236+
"""
237+
context = sampling_context.context
238+
sampler = Sampler(context, sampling_context.method, callback)
239+
sampling_context.run = sampler.run
240+
241+
return sampler

0 commit comments

Comments
 (0)