Skip to content

Commit 59d0e1f

Browse files
author
Steven Ayoub
committed
Current support version of integrator seems to be computing larger work than expected
1 parent a10a42c commit 59d0e1f

File tree

1 file changed

+194
-11
lines changed

1 file changed

+194
-11
lines changed

blues/integrators.py

Lines changed: 194 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import openmm
22
from openmmtools.integrators import AlchemicalNonequilibriumLangevinIntegrator
33
import logging
4-
4+
from openmm import unit
55
logger = logging.getLogger(__name__)
66
# Energy unit used by OpenMM unit system
77
_OPENMM_ENERGY_UNIT = openmm.unit.kilojoules_per_mole
@@ -140,12 +140,13 @@ def __init__(self,
140140
#$self._registered_step_types['H'] = (
141141
# self._add_alchemical_perturbation_step, False)
142142
self.addGlobalVariable("debug", 0)
143-
143+
logger.info(f'splitting: {splitting}')
144144
try:
145145
self.getGlobalVariableByName("shadow_work")
146146
except:
147147
self.addGlobalVariable('shadow_work', 0)
148148

149+
149150
def _get_prop_lambda(self, prop_lambda):
150151
prop_lambda_max = round(prop_lambda + 0.5, 4)
151152
prop_lambda_min = round(0.5 - prop_lambda, 4)
@@ -157,8 +158,61 @@ def _get_prop_lambda(self, prop_lambda):
157158
prop_lambda_max = -1.0
158159

159160
return prop_lambda_min, prop_lambda_max
161+
162+
163+
def _add_integrator_steps(self):
164+
"""
165+
Override the base class to insert reset steps around the integrator.
166+
"""
167+
168+
# First step: Constrain positions and velocities and reset work accumulators and alchemical integrators
169+
self.beginIfBlock('step = 0')
170+
self.addComputeGlobal("perturbed_pe", "energy")
171+
self.addComputeGlobal("unperturbed_pe", "energy")
172+
self.addConstrainPositions()
173+
self.addConstrainVelocities()
174+
self._add_reset_protocol_work_step()
175+
self._add_alchemical_reset_step()
176+
self.endBlock()
177+
178+
# Main body
179+
if self._n_steps_neq == 0:
180+
# If nsteps = 0, we need to force execution on the first step only.
181+
self.beginIfBlock('step = 0')
182+
super(AlchemicalNonequilibriumLangevinIntegrator, self)._add_integrator_steps()
183+
self.addComputeGlobal("step", "step + 1")
184+
self.endBlock()
185+
else:
186+
#call the superclass function to insert the appropriate steps, provided the step number is less than n_steps
187+
self.beginIfBlock("step < n_lambda_steps")
188+
self.addComputeGlobal("perturbed_pe", "energy")
189+
self.beginIfBlock("first_step < 1")
190+
#TODO write better test that checks that the initial work isn't gigantic
191+
self.addComputeGlobal("first_step", "1")
192+
self.addComputeGlobal("unperturbed_pe", "energy")
193+
self.endBlock()
194+
#initial iteration
195+
self.addComputeGlobal("protocol_work", "protocol_work + (perturbed_pe - unperturbed_pe)")
196+
super(AlchemicalNonequilibriumLangevinIntegrator, self)._add_integrator_steps()
197+
#if more propogation steps are requested
198+
self.beginIfBlock("lambda > prop_lambda_min")
199+
self.beginIfBlock("lambda <= prop_lambda_max")
200+
201+
self.beginWhileBlock("prop < nprop")
202+
self.addComputeGlobal("prop", "prop + 1")
203+
204+
super(AlchemicalNonequilibriumLangevinIntegrator, self)._add_integrator_steps()
205+
self.endBlock()
206+
self.endBlock()
207+
self.endBlock()
208+
#ending variables to reset
209+
self.addComputeGlobal("unperturbed_pe", "energy")
210+
self.addComputeGlobal("step", "step + 1")
211+
self.addComputeGlobal("prop", "1")
212+
213+
self.endBlock()
214+
160215

161-
162216

163217
def _add_alchemical_perturbation_step(self):
164218
"""
@@ -187,14 +241,8 @@ def getLogAcceptanceProbability(self, context):
187241
protocol = self.getGlobalVariableByName("protocol_work")
188242
shadow = self.getGlobalVariableByName("shadow_work")
189243
logp_accept = -1.0 * (protocol + shadow) * _OPENMM_ENERGY_UNIT / self.kT
190-
191-
# sa
192-
import numpy as np
193-
logger.info(f"Protocol work: {protocol}")
194-
logger.info(f"Shadow work: {shadow}")
195-
logger.info(f"log_accept_prob: {logp_accept}")
196-
logger.info(f"acceptance_prob: {np.exp(logp_accept)}")
197-
244+
logger.info(f'[WORK] protocol work: {protocol}')
245+
logger.info(f'[shadow] shadow: {shadow}')
198246
return logp_accept
199247

200248
def reset(self):
@@ -207,3 +255,138 @@ def reset(self):
207255
self.setGlobalVariableByName("unperturbed_pe", 0.0)
208256
self.setGlobalVariableByName("prop", 1)
209257
super(AlchemicalExternalLangevinIntegrator, self).reset()
258+
259+
260+
#TODO: Add a class for the restrained integrator
261+
# Still need to test the restrained integrator
262+
class AlchemicalExternalRestrainedLangevinIntegrator(AlchemicalExternalLangevinIntegrator):
263+
def __init__(self,
264+
alchemical_functions,
265+
restraint_group,
266+
splitting="R V O H O V R",
267+
temperature=298.0 * unit.kelvin,
268+
collision_rate=1.0 / unit.picoseconds,
269+
timestep=1.0 * unit.femtoseconds,
270+
constraint_tolerance=1e-8,
271+
measure_shadow_work=False,
272+
measure_heat=True,
273+
nsteps_neq=0,
274+
nprop=1,
275+
prop_lambda=0.3,
276+
lambda_restraints = 'max(0, 1-(1/0.10)*abs(lambda-0.5))',
277+
#relax_steps=500, #'max(0, 1-(1/0.10)*abs(lambda-0.5))', #"3*lambda^2 - 2*lambda^3", # old: 'max(0, 1-(1/0.10)*abs(lambda-0.5))'
278+
relax_steps=50,
279+
*args, **kwargs):
280+
281+
self.lambda_restraints = lambda_restraints
282+
self.restraint_energy = "energy"+str(restraint_group)
283+
284+
super(AlchemicalExternalRestrainedLangevinIntegrator, self).__init__(
285+
alchemical_functions,
286+
splitting,
287+
temperature,
288+
collision_rate,
289+
timestep,
290+
constraint_tolerance,
291+
measure_shadow_work,
292+
measure_heat,
293+
nsteps_neq,
294+
nprop,
295+
prop_lambda,
296+
*args, **kwargs)
297+
298+
try:
299+
self.addGlobalVariable("restraint_energy", 0)
300+
except:
301+
pass
302+
logger.info(f'[LAMBDA STEPS] N_lambda_steps: {self._n_lambda_steps}')
303+
# Only declare NEW variables
304+
self.addGlobalVariable("debug_lambda", 0.0)
305+
#self.addGlobalVariable("restraint_energy", 0.0)
306+
307+
# Set existing globals from parent
308+
# self.setGlobalVariableByName("lambda_step", 0.0)
309+
# self.setGlobalVariableByName("lambda", 0.0)
310+
311+
# Optional debug
312+
self.addComputeGlobal("debug_lambda", "lambda")
313+
314+
# Now safe to use lambda_restraints in update
315+
#self.updateRestraints()
316+
317+
logger.info(f"Current nsteps_neq: {nsteps_neq}")
318+
logger.info(f'lambda_restraints selected: {self.lambda_restraints}')
319+
# compute the mid‐point slice index once
320+
mid = int(self._n_lambda_steps/2)
321+
self.addGlobalVariable("mid_step", float(mid))
322+
self.addGlobalVariable("relax_counter", 0.0)
323+
self.addGlobalVariable("relax_steps", float(relax_steps))
324+
325+
def updateRestraints(self):
326+
logger.info(f"UPDATE RESTAINTS: {self.lambda_restraints}")
327+
self.addComputeGlobal('lambda_restraints', self.lambda_restraints)
328+
329+
330+
def _add_integrator_steps(self):
331+
"""
332+
Override the base class to insert reset steps around the integrator.
333+
"""
334+
335+
# First step: Constrain positions and velocities and reset work accumulators and alchemical integrators
336+
logger.info("sams protocol")
337+
self.beginIfBlock('step = 0')
338+
self.addComputeGlobal("restraint_energy", self.restraint_energy)
339+
self.addComputeGlobal("perturbed_pe", "energy - restraint_energy")
340+
self.addComputeGlobal("unperturbed_pe", "energy - restraint_energy")
341+
self.addConstrainPositions()
342+
self.addConstrainVelocities()
343+
self._add_reset_protocol_work_step()
344+
self._add_alchemical_reset_step()
345+
self.endBlock()
346+
347+
# Main body
348+
349+
if self._n_steps_neq == 0:
350+
# If nsteps = 0, we need to force execution on the first step only.
351+
self.beginIfBlock('step = 0')
352+
super(AlchemicalNonequilibriumLangevinIntegrator, self)._add_integrator_steps()
353+
self.addComputeGlobal("step", "step + 1")
354+
self.endBlock()
355+
else:
356+
#call the superclass function to insert the appropriate steps, provided the step number is less than n_steps
357+
self.beginIfBlock("step < n_lambda_steps")#
358+
self.addComputeGlobal("restraint_energy", self.restraint_energy)
359+
self.addComputeGlobal("perturbed_pe", "energy - restraint_energy")
360+
self.beginIfBlock("first_step < 1")##
361+
#TODO write better test that checks that the initial work isn't gigantic
362+
self.addComputeGlobal("first_step", "1")
363+
self.addComputeGlobal("restraint_energy", self.restraint_energy)
364+
self.addComputeGlobal("unperturbed_pe", "energy-restraint_energy")
365+
self.endBlock()##
366+
#initial iteration
367+
# Gill put this first:
368+
self.addComputeGlobal("protocol_work",
369+
"protocol_work + (perturbed_pe - unperturbed_pe)"
370+
)
371+
super(AlchemicalNonequilibriumLangevinIntegrator, self)._add_integrator_steps()
372+
logger.info("COMPUTE WORKSSSS")
373+
#if more propogation steps are requested
374+
self.beginIfBlock("lambda > prop_lambda_min")###
375+
self.beginIfBlock("lambda <= prop_lambda_max")####
376+
377+
self.beginWhileBlock("prop < nprop")#####
378+
self.addComputeGlobal("prop", "prop + 1")
379+
380+
super(AlchemicalNonequilibriumLangevinIntegrator, self)._add_integrator_steps()
381+
self.endBlock()#####
382+
self.endBlock()####
383+
self.endBlock()###
384+
#ending variables to reset
385+
self.updateRestraints()
386+
self.addComputeGlobal("restraint_energy", self.restraint_energy)
387+
self.addComputeGlobal("unperturbed_pe", "energy-restraint_energy")
388+
self.addComputeGlobal("step", "step + 1")
389+
self.addComputeGlobal("prop", "1")
390+
391+
self.endBlock()#
392+

0 commit comments

Comments
 (0)