Skip to content

Commit be31895

Browse files
committed
new integrator
1 parent 213ab17 commit be31895

File tree

3 files changed

+280
-25
lines changed

3 files changed

+280
-25
lines changed

blues/integrators.py

Lines changed: 229 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import openmm
22
from openmmtools.integrators import AlchemicalNonequilibriumLangevinIntegrator
33
import logging
4-
4+
from openmm import unit
55
logger = logging.getLogger(__name__)
66
# Energy unit used by OpenMM unit system
77
_OPENMM_ENERGY_UNIT = openmm.unit.kilojoules_per_mole
@@ -99,7 +99,7 @@ class AlchemicalExternalLangevinIntegrator(AlchemicalNonequilibriumLangevinInteg
9999

100100
def __init__(self,
101101
alchemical_functions,
102-
splitting="R V O H O V R",
102+
splitting="H V R O V R H",
103103
temperature=298.0 * openmm.unit.kelvin,
104104
collision_rate=1.0 / openmm.unit.picoseconds,
105105
timestep=1.0 * openmm.unit.femtoseconds,
@@ -135,17 +135,21 @@ def __init__(self,
135135
self.addGlobalVariable("prop", 1)
136136
self.addGlobalVariable("prop_lambda_min", self._prop_lambda[0])
137137
self.addGlobalVariable("prop_lambda_max", self._prop_lambda[1])
138+
self.addGlobalVariable("debug_work", 0.0)
139+
self.addGlobalVariable("work_0_to_05", 0.0)
140+
self.addGlobalVariable("work_05_to_1", 0.0)
138141
# Behavior changed in https://github.com/choderalab/openmmtools/commit/7c2630050631e126d61b67f56e941de429b2d643#diff-5ce4bc8893e544833c827299a5d48b0d
139142
self._step_dispatch_table['H'] = (self._add_alchemical_perturbation_step, False)
140143
#$self._registered_step_types['H'] = (
141144
# self._add_alchemical_perturbation_step, False)
142145
self.addGlobalVariable("debug", 0)
143-
146+
logger.info(f'splitting: {splitting}')
144147
try:
145148
self.getGlobalVariableByName("shadow_work")
146149
except:
147150
self.addGlobalVariable('shadow_work', 0)
148151

152+
149153
def _get_prop_lambda(self, prop_lambda):
150154
prop_lambda_max = round(prop_lambda + 0.5, 4)
151155
prop_lambda_min = round(0.5 - prop_lambda, 4)
@@ -157,8 +161,63 @@ def _get_prop_lambda(self, prop_lambda):
157161
prop_lambda_max = -1.0
158162

159163
return prop_lambda_min, prop_lambda_max
164+
165+
166+
def _add_integrator_steps(self):
167+
"""
168+
Override the base class to insert reset steps around the integrator.
169+
"""
170+
171+
# First step: Constrain positions and velocities and reset work accumulators and alchemical integrators
172+
self.beginIfBlock('step = 0')
173+
self.addComputeGlobal("perturbed_pe", "energy")
174+
self.addComputeGlobal("unperturbed_pe", "energy")
175+
self.addConstrainPositions()
176+
self.addConstrainVelocities()
177+
self._add_reset_protocol_work_step()
178+
self._add_alchemical_reset_step()
179+
self.endBlock()
180+
181+
# Main body
182+
if self._n_steps_neq == 0:
183+
# If nsteps = 0, we need to force execution on the first step only.
184+
self.beginIfBlock('step = 0')
185+
super(AlchemicalNonequilibriumLangevinIntegrator, self)._add_integrator_steps()
186+
self.addComputeGlobal("step", "step + 1")
187+
self.endBlock()
188+
else:
189+
#call the superclass function to insert the appropriate steps, provided the step number is less than n_steps
190+
self.beginIfBlock("step < n_lambda_steps")
191+
self.addComputeGlobal("perturbed_pe", "energy")
192+
self.beginIfBlock("first_step < 1")
193+
#TODO write better test that checks that the initial work isn't gigantic
194+
self.addComputeGlobal("first_step", "1")
195+
self.addComputeGlobal("unperturbed_pe", "energy")
196+
self.endBlock()
197+
#initial iteration
198+
# Work accumulation is handled in _add_alchemical_perturbation_step()
199+
super(AlchemicalNonequilibriumLangevinIntegrator, self)._add_integrator_steps()
200+
#if more propogation steps are requested
201+
self.beginIfBlock("lambda > prop_lambda_min")
202+
self.beginIfBlock("lambda <= prop_lambda_max")
203+
204+
self.beginWhileBlock("prop < nprop")
205+
self.addComputeGlobal("prop", "prop + 1")
206+
super(AlchemicalNonequilibriumLangevinIntegrator, self)._add_integrator_steps()
207+
# Propagation steps - just do additional integration without lambda changes
208+
# The parent's integrator steps are already called above, so we don't need to call them again
209+
# This prevents double work accumulation
210+
self.endBlock()
211+
self.endBlock()
212+
self.endBlock()
213+
#ending variables to reset
214+
self.addComputeGlobal("unperturbed_pe", "energy")
215+
self.addComputeGlobal("step", "step + 1")
216+
self.addComputeGlobal("prop", "1")
217+
218+
self.endBlock()
219+
160220

161-
162221

163222
def _add_alchemical_perturbation_step(self):
164223
"""
@@ -180,30 +239,190 @@ def _add_alchemical_perturbation_step(self):
180239
# Accumulate protocol work
181240
self.addComputeGlobal("Enew", "energy")
182241
self.addComputeGlobal("protocol_work", "protocol_work + (Enew-Eold)")
242+
243+
# Track work in different phases
244+
self.beginIfBlock("lambda <= 0.5")
245+
self.addComputeGlobal("work_0_to_05", "work_0_to_05 + (Enew-Eold)")
246+
self.endBlock()
247+
248+
self.beginIfBlock("lambda > 0.5")
249+
self.addComputeGlobal("work_05_to_1", "work_05_to_1 + (Enew-Eold)")
250+
self.endBlock()
251+
252+
# Debug: Print work at move step (λ=0.5)
253+
self.beginIfBlock("abs(lambda - 0.5) < 0.001")
254+
self.addComputeGlobal("debug_work", "Enew - Eold")
255+
self.endBlock()
183256
self.endBlock()
184257

185258
def getLogAcceptanceProbability(self, context):
186259
#TODO remove context from arguments if/once ncmc_switching is changed
187260
protocol = self.getGlobalVariableByName("protocol_work")
188261
shadow = self.getGlobalVariableByName("shadow_work")
189262
logp_accept = -1.0 * (protocol + shadow) * _OPENMM_ENERGY_UNIT / self.kT
263+
logger.info(f'[WORK] protocol work: {protocol}')
264+
logger.info(f'[shadow] shadow: {shadow}')
190265

191-
# sa
192-
import numpy as np
193-
logger.info(f"Protocol work: {protocol}")
194-
logger.info(f"Shadow work: {shadow}")
195-
logger.info(f"log_accept_prob: {logp_accept}")
196-
logger.info(f"acceptance_prob: {np.exp(logp_accept)}")
266+
# Debug: Print work in different phases
267+
try:
268+
work_0_to_05 = self.getGlobalVariableByName("work_0_to_05")
269+
work_05_to_1 = self.getGlobalVariableByName("work_05_to_1")
270+
debug_work = self.getGlobalVariableByName("debug_work")
271+
logger.info(f'[DEBUG] Work 0→0.5: {work_0_to_05}')
272+
logger.info(f'[DEBUG] Work 0.5→1.0: {work_05_to_1}')
273+
logger.info(f'[DEBUG] Work at λ=0.5: {debug_work}')
274+
except:
275+
pass
197276

198277
return logp_accept
199278

200279
def reset(self):
201280
self.setGlobalVariableByName("step", 0)
202281
self.setGlobalVariableByName("lambda", 0.0)
282+
self.setGlobalVariableByName("lambda_step", 0.0)
203283
self.setGlobalVariableByName("protocol_work", 0.0)
204284
self.setGlobalVariableByName("shadow_work", 0.0)
205285
self.setGlobalVariableByName("first_step", 0)
206286
self.setGlobalVariableByName("perturbed_pe", 0.0)
207287
self.setGlobalVariableByName("unperturbed_pe", 0.0)
208288
self.setGlobalVariableByName("prop", 1)
289+
self.setGlobalVariableByName("debug_work", 0.0)
290+
self.setGlobalVariableByName("work_0_to_05", 0.0)
291+
self.setGlobalVariableByName("work_05_to_1", 0.0)
209292
super(AlchemicalExternalLangevinIntegrator, self).reset()
293+
294+
295+
#TODO: Add a class for the restrained integrator
296+
# Still need to test the restrained integrator
297+
class AlchemicalExternalRestrainedLangevinIntegrator(AlchemicalExternalLangevinIntegrator):
298+
def __init__(self,
299+
alchemical_functions,
300+
restraint_group,
301+
splitting="R V O H O V R",
302+
temperature=298.0 * unit.kelvin,
303+
collision_rate=1.0 / unit.picoseconds,
304+
timestep=1.0 * unit.femtoseconds,
305+
constraint_tolerance=1e-8,
306+
measure_shadow_work=False,
307+
measure_heat=True,
308+
nsteps_neq=0,
309+
nprop=1,
310+
prop_lambda=0.3,
311+
lambda_restraints = 'max(0, 1-(1/0.10)*abs(lambda-0.5))',
312+
#relax_steps=500, #'max(0, 1-(1/0.10)*abs(lambda-0.5))', #"3*lambda^2 - 2*lambda^3", # old: 'max(0, 1-(1/0.10)*abs(lambda-0.5))'
313+
relax_steps=50,
314+
*args, **kwargs):
315+
316+
self.lambda_restraints = lambda_restraints
317+
self.restraint_energy = "energy"+str(restraint_group)
318+
319+
super(AlchemicalExternalRestrainedLangevinIntegrator, self).__init__(
320+
alchemical_functions,
321+
splitting,
322+
temperature,
323+
collision_rate,
324+
timestep,
325+
constraint_tolerance,
326+
measure_shadow_work,
327+
measure_heat,
328+
nsteps_neq,
329+
nprop,
330+
prop_lambda,
331+
*args, **kwargs)
332+
333+
try:
334+
self.addGlobalVariable("restraint_energy", 0)
335+
except:
336+
pass
337+
logger.info(f'[LAMBDA STEPS] N_lambda_steps: {self._n_lambda_steps}')
338+
# Only declare NEW variables
339+
self.addGlobalVariable("debug_lambda", 0.0)
340+
#self.addGlobalVariable("restraint_energy", 0.0)
341+
342+
# Set existing globals from parent
343+
# self.setGlobalVariableByName("lambda_step", 0.0)
344+
# self.setGlobalVariableByName("lambda", 0.0)
345+
346+
# Optional debug
347+
self.addComputeGlobal("debug_lambda", "lambda")
348+
349+
# Now safe to use lambda_restraints in update
350+
#self.updateRestraints()
351+
352+
logger.info(f"Current nsteps_neq: {nsteps_neq}")
353+
logger.info(f'lambda_restraints selected: {self.lambda_restraints}')
354+
# compute the mid‐point slice index once
355+
mid = int(self._n_lambda_steps/2)
356+
self.addGlobalVariable("mid_step", float(mid))
357+
self.addGlobalVariable("relax_counter", 0.0)
358+
self.addGlobalVariable("relax_steps", float(relax_steps))
359+
360+
def updateRestraints(self):
361+
logger.info(f"UPDATE RESTAINTS: {self.lambda_restraints}")
362+
self.addComputeGlobal('lambda_restraints', self.lambda_restraints)
363+
364+
365+
def _add_integrator_steps(self):
366+
"""
367+
Override the base class to insert reset steps around the integrator.
368+
"""
369+
370+
# First step: Constrain positions and velocities and reset work accumulators and alchemical integrators
371+
logger.info("sams protocol")
372+
self.beginIfBlock('step = 0')
373+
self.addComputeGlobal("restraint_energy", self.restraint_energy)
374+
self.addComputeGlobal("perturbed_pe", "energy - restraint_energy")
375+
self.addComputeGlobal("unperturbed_pe", "energy - restraint_energy")
376+
self.addConstrainPositions()
377+
self.addConstrainVelocities()
378+
self._add_reset_protocol_work_step()
379+
self._add_alchemical_reset_step()
380+
self.endBlock()
381+
382+
# Main body
383+
384+
if self._n_steps_neq == 0:
385+
# If nsteps = 0, we need to force execution on the first step only.
386+
self.beginIfBlock('step = 0')
387+
super(AlchemicalNonequilibriumLangevinIntegrator, self)._add_integrator_steps()
388+
self.addComputeGlobal("step", "step + 1")
389+
self.endBlock()
390+
else:
391+
#call the superclass function to insert the appropriate steps, provided the step number is less than n_steps
392+
self.beginIfBlock("step < n_lambda_steps")#
393+
self.addComputeGlobal("restraint_energy", self.restraint_energy)
394+
self.addComputeGlobal("perturbed_pe", "energy - restraint_energy")
395+
self.beginIfBlock("first_step < 1")##
396+
#TODO write better test that checks that the initial work isn't gigantic
397+
self.addComputeGlobal("first_step", "1")
398+
self.addComputeGlobal("restraint_energy", self.restraint_energy)
399+
self.addComputeGlobal("unperturbed_pe", "energy-restraint_energy")
400+
self.endBlock()##
401+
#initial iteration
402+
# Gill put this first:
403+
self.addComputeGlobal("protocol_work",
404+
"protocol_work + (perturbed_pe - unperturbed_pe)"
405+
)
406+
super(AlchemicalNonequilibriumLangevinIntegrator, self)._add_integrator_steps()
407+
logger.info("COMPUTE WORKSSSS")
408+
#if more propogation steps are requested
409+
self.beginIfBlock("lambda > prop_lambda_min")###
410+
self.beginIfBlock("lambda <= prop_lambda_max")####
411+
412+
self.beginWhileBlock("prop < nprop")#####
413+
self.addComputeGlobal("prop", "prop + 1")
414+
415+
super(AlchemicalNonequilibriumLangevinIntegrator, self)._add_integrator_steps()
416+
self.endBlock()#####
417+
self.endBlock()####
418+
self.endBlock()###
419+
#ending variables to reset
420+
self.updateRestraints()
421+
self.addComputeGlobal("restraint_energy", self.restraint_energy)
422+
self.addComputeGlobal("unperturbed_pe", "energy-restraint_energy")
423+
self.addComputeGlobal("step", "step + 1")
424+
self.addComputeGlobal("prop", "1")
425+
426+
self.endBlock()#
427+
428+

blues/moves.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -510,12 +510,27 @@ def __init__(self, structure, residue_list, verbose=False, write_move=False):
510510
self.verbose = verbose
511511
self.write_move = write_move
512512

513+
#def _pmdStructureToOEMol(self):
514+
# """Helper function for converting the parmed structure into an OEMolecule."""
515+
# top = self.structure.topology
516+
# pos = self.structure.positions
517+
# molecule = oeommtools.openmmTop_to_oemol(top, pos)
518+
# oechem.OEPerceiveResidues(molecule)
519+
# oechem.OEFindRingAtomsAndBonds(molecule)
520+
521+
# return molecule
522+
513523
def _pmdStructureToOEMol(self):
514524
"""Helper function for converting the parmed structure into an OEMolecule."""
525+
#structure_LIG = parmed.load_file(prmtop, xyz = inpcrd)
515526
top = self.structure.topology
516527
pos = self.structure.positions
517-
molecule = oeommtools.openmmTop_to_oemol(top, pos)
518-
oechem.OEPerceiveResidues(molecule)
528+
529+
molecule = utils.openmmTop_to_oemol(top, pos, verbose=False)
530+
531+
# Extract coordinates (in Å) and add as conformer
532+
oechem.OEPerceiveBondOrders(molecule)
533+
oechem.OEAssignAromaticFlags(molecule)
519534
oechem.OEFindRingAtomsAndBonds(molecule)
520535

521536
return molecule
@@ -1654,11 +1669,12 @@ def __init__(self, structure, residue_list, verbose=False, write_move=False):
16541669
self.verbose = verbose
16551670
self.write_move = write_move
16561671

1672+
16571673
def _pmdStructureToOEMol(self):
16581674
"""Helper function for converting the parmed structure into an OEMolecule."""
16591675
top = self.structure.topology
16601676
pos = self.structure.positions
1661-
molecule = oeommtools.openmmTop_to_oemol(top, pos, verbose=False)
1677+
molecule = utils.openmmTop_to_oemol(top, pos, verbose=False)
16621678
oechem.OEPerceiveResidues(molecule)
16631679
oechem.OEFindRingAtomsAndBonds(molecule)
16641680

@@ -2075,11 +2091,28 @@ def __init__(self, structure, prmtop, inpcrd, dihedral_atoms, alch_list, resname
20752091
print("AIL:", self.atom_indices_ligand)
20762092
self.dihedral_atoms = dihedral_atoms
20772093
self.positions = structure[self.atom_indices_ligand].positions
2094+
20782095
self.molecule = self._pmdStructureToOEMol(prmtop, inpcrd, resname)
20792096

20802097
def _pmdStructureToOEMol(self, prmtop, inpcrd, resname):
20812098
"""Helper function for converting the parmed structure into an OEMolecule."""
2082-
structure_LIG = parmed.load_file(prmtop, xyz = inpcrd)
2099+
from openmm import XmlSerializer
2100+
from openmm.app import PDBFile
2101+
2102+
# Load system
2103+
with open(prmtop) as f:
2104+
system = XmlSerializer.deserialize(f.read())
2105+
pdb = PDBFile(inpcrd)
2106+
topology = pdb.topology
2107+
positions = pdb.getPositions()
2108+
2109+
2110+
structure_LIG = parmed.openmm.load_topology(
2111+
topology,
2112+
system=system,
2113+
xyz=positions
2114+
)
2115+
#structure_LIG = parmed.load_file(prmtop, xyz = inpcrd)
20832116
mask = "!(:%s)" %resname
20842117
structure_LIG.strip(mask)
20852118
top = structure_LIG.topology

0 commit comments

Comments
 (0)