11import openmm
22from openmmtools .integrators import AlchemicalNonequilibriumLangevinIntegrator
33import logging
4-
4+ from openmm import unit
55logger = logging .getLogger (__name__ )
66# Energy unit used by OpenMM unit system
77_OPENMM_ENERGY_UNIT = openmm .unit .kilojoules_per_mole
@@ -140,12 +140,13 @@ def __init__(self,
140140 #$self._registered_step_types['H'] = (
141141 # self._add_alchemical_perturbation_step, False)
142142 self .addGlobalVariable ("debug" , 0 )
143-
143+ logger . info ( f'splitting: { splitting } ' )
144144 try :
145145 self .getGlobalVariableByName ("shadow_work" )
146146 except :
147147 self .addGlobalVariable ('shadow_work' , 0 )
148148
149+
149150 def _get_prop_lambda (self , prop_lambda ):
150151 prop_lambda_max = round (prop_lambda + 0.5 , 4 )
151152 prop_lambda_min = round (0.5 - prop_lambda , 4 )
@@ -157,8 +158,61 @@ def _get_prop_lambda(self, prop_lambda):
157158 prop_lambda_max = - 1.0
158159
159160 return prop_lambda_min , prop_lambda_max
161+
162+
163+ def _add_integrator_steps (self ):
164+ """
165+ Override the base class to insert reset steps around the integrator.
166+ """
167+
168+ # First step: Constrain positions and velocities and reset work accumulators and alchemical integrators
169+ self .beginIfBlock ('step = 0' )
170+ self .addComputeGlobal ("perturbed_pe" , "energy" )
171+ self .addComputeGlobal ("unperturbed_pe" , "energy" )
172+ self .addConstrainPositions ()
173+ self .addConstrainVelocities ()
174+ self ._add_reset_protocol_work_step ()
175+ self ._add_alchemical_reset_step ()
176+ self .endBlock ()
177+
178+ # Main body
179+ if self ._n_steps_neq == 0 :
180+ # If nsteps = 0, we need to force execution on the first step only.
181+ self .beginIfBlock ('step = 0' )
182+ super (AlchemicalNonequilibriumLangevinIntegrator , self )._add_integrator_steps ()
183+ self .addComputeGlobal ("step" , "step + 1" )
184+ self .endBlock ()
185+ else :
186+ #call the superclass function to insert the appropriate steps, provided the step number is less than n_steps
187+ self .beginIfBlock ("step < n_lambda_steps" )
188+ self .addComputeGlobal ("perturbed_pe" , "energy" )
189+ self .beginIfBlock ("first_step < 1" )
190+ #TODO write better test that checks that the initial work isn't gigantic
191+ self .addComputeGlobal ("first_step" , "1" )
192+ self .addComputeGlobal ("unperturbed_pe" , "energy" )
193+ self .endBlock ()
194+ #initial iteration
195+ self .addComputeGlobal ("protocol_work" , "protocol_work + (perturbed_pe - unperturbed_pe)" )
196+ super (AlchemicalNonequilibriumLangevinIntegrator , self )._add_integrator_steps ()
197+ #if more propogation steps are requested
198+ self .beginIfBlock ("lambda > prop_lambda_min" )
199+ self .beginIfBlock ("lambda <= prop_lambda_max" )
200+
201+ self .beginWhileBlock ("prop < nprop" )
202+ self .addComputeGlobal ("prop" , "prop + 1" )
203+
204+ super (AlchemicalNonequilibriumLangevinIntegrator , self )._add_integrator_steps ()
205+ self .endBlock ()
206+ self .endBlock ()
207+ self .endBlock ()
208+ #ending variables to reset
209+ self .addComputeGlobal ("unperturbed_pe" , "energy" )
210+ self .addComputeGlobal ("step" , "step + 1" )
211+ self .addComputeGlobal ("prop" , "1" )
212+
213+ self .endBlock ()
214+
160215
161-
162216
163217 def _add_alchemical_perturbation_step (self ):
164218 """
@@ -187,14 +241,8 @@ def getLogAcceptanceProbability(self, context):
187241 protocol = self .getGlobalVariableByName ("protocol_work" )
188242 shadow = self .getGlobalVariableByName ("shadow_work" )
189243 logp_accept = - 1.0 * (protocol + shadow ) * _OPENMM_ENERGY_UNIT / self .kT
190-
191- # sa
192- import numpy as np
193- logger .info (f"Protocol work: { protocol } " )
194- logger .info (f"Shadow work: { shadow } " )
195- logger .info (f"log_accept_prob: { logp_accept } " )
196- logger .info (f"acceptance_prob: { np .exp (logp_accept )} " )
197-
244+ logger .info (f'[WORK] protocol work: { protocol } ' )
245+ logger .info (f'[shadow] shadow: { shadow } ' )
198246 return logp_accept
199247
200248 def reset (self ):
@@ -207,3 +255,138 @@ def reset(self):
207255 self .setGlobalVariableByName ("unperturbed_pe" , 0.0 )
208256 self .setGlobalVariableByName ("prop" , 1 )
209257 super (AlchemicalExternalLangevinIntegrator , self ).reset ()
258+
259+
260+ #TODO: Add a class for the restrained integrator
261+ # Still need to test the restrained integrator
262+ class AlchemicalExternalRestrainedLangevinIntegrator (AlchemicalExternalLangevinIntegrator ):
263+ def __init__ (self ,
264+ alchemical_functions ,
265+ restraint_group ,
266+ splitting = "R V O H O V R" ,
267+ temperature = 298.0 * unit .kelvin ,
268+ collision_rate = 1.0 / unit .picoseconds ,
269+ timestep = 1.0 * unit .femtoseconds ,
270+ constraint_tolerance = 1e-8 ,
271+ measure_shadow_work = False ,
272+ measure_heat = True ,
273+ nsteps_neq = 0 ,
274+ nprop = 1 ,
275+ prop_lambda = 0.3 ,
276+ lambda_restraints = 'max(0, 1-(1/0.10)*abs(lambda-0.5))' ,
277+ #relax_steps=500, #'max(0, 1-(1/0.10)*abs(lambda-0.5))', #"3*lambda^2 - 2*lambda^3", # old: 'max(0, 1-(1/0.10)*abs(lambda-0.5))'
278+ relax_steps = 50 ,
279+ * args , ** kwargs ):
280+
281+ self .lambda_restraints = lambda_restraints
282+ self .restraint_energy = "energy" + str (restraint_group )
283+
284+ super (AlchemicalExternalRestrainedLangevinIntegrator , self ).__init__ (
285+ alchemical_functions ,
286+ splitting ,
287+ temperature ,
288+ collision_rate ,
289+ timestep ,
290+ constraint_tolerance ,
291+ measure_shadow_work ,
292+ measure_heat ,
293+ nsteps_neq ,
294+ nprop ,
295+ prop_lambda ,
296+ * args , ** kwargs )
297+
298+ try :
299+ self .addGlobalVariable ("restraint_energy" , 0 )
300+ except :
301+ pass
302+ logger .info (f'[LAMBDA STEPS] N_lambda_steps: { self ._n_lambda_steps } ' )
303+ # Only declare NEW variables
304+ self .addGlobalVariable ("debug_lambda" , 0.0 )
305+ #self.addGlobalVariable("restraint_energy", 0.0)
306+
307+ # Set existing globals from parent
308+ # self.setGlobalVariableByName("lambda_step", 0.0)
309+ # self.setGlobalVariableByName("lambda", 0.0)
310+
311+ # Optional debug
312+ self .addComputeGlobal ("debug_lambda" , "lambda" )
313+
314+ # Now safe to use lambda_restraints in update
315+ #self.updateRestraints()
316+
317+ logger .info (f"Current nsteps_neq: { nsteps_neq } " )
318+ logger .info (f'lambda_restraints selected: { self .lambda_restraints } ' )
319+ # compute the mid‐point slice index once
320+ mid = int (self ._n_lambda_steps / 2 )
321+ self .addGlobalVariable ("mid_step" , float (mid ))
322+ self .addGlobalVariable ("relax_counter" , 0.0 )
323+ self .addGlobalVariable ("relax_steps" , float (relax_steps ))
324+
325+ def updateRestraints (self ):
326+ logger .info (f"UPDATE RESTAINTS: { self .lambda_restraints } " )
327+ self .addComputeGlobal ('lambda_restraints' , self .lambda_restraints )
328+
329+
330+ def _add_integrator_steps (self ):
331+ """
332+ Override the base class to insert reset steps around the integrator.
333+ """
334+
335+ # First step: Constrain positions and velocities and reset work accumulators and alchemical integrators
336+ logger .info ("sams protocol" )
337+ self .beginIfBlock ('step = 0' )
338+ self .addComputeGlobal ("restraint_energy" , self .restraint_energy )
339+ self .addComputeGlobal ("perturbed_pe" , "energy - restraint_energy" )
340+ self .addComputeGlobal ("unperturbed_pe" , "energy - restraint_energy" )
341+ self .addConstrainPositions ()
342+ self .addConstrainVelocities ()
343+ self ._add_reset_protocol_work_step ()
344+ self ._add_alchemical_reset_step ()
345+ self .endBlock ()
346+
347+ # Main body
348+
349+ if self ._n_steps_neq == 0 :
350+ # If nsteps = 0, we need to force execution on the first step only.
351+ self .beginIfBlock ('step = 0' )
352+ super (AlchemicalNonequilibriumLangevinIntegrator , self )._add_integrator_steps ()
353+ self .addComputeGlobal ("step" , "step + 1" )
354+ self .endBlock ()
355+ else :
356+ #call the superclass function to insert the appropriate steps, provided the step number is less than n_steps
357+ self .beginIfBlock ("step < n_lambda_steps" )#
358+ self .addComputeGlobal ("restraint_energy" , self .restraint_energy )
359+ self .addComputeGlobal ("perturbed_pe" , "energy - restraint_energy" )
360+ self .beginIfBlock ("first_step < 1" )##
361+ #TODO write better test that checks that the initial work isn't gigantic
362+ self .addComputeGlobal ("first_step" , "1" )
363+ self .addComputeGlobal ("restraint_energy" , self .restraint_energy )
364+ self .addComputeGlobal ("unperturbed_pe" , "energy-restraint_energy" )
365+ self .endBlock ()##
366+ #initial iteration
367+ # Gill put this first:
368+ self .addComputeGlobal ("protocol_work" ,
369+ "protocol_work + (perturbed_pe - unperturbed_pe)"
370+ )
371+ super (AlchemicalNonequilibriumLangevinIntegrator , self )._add_integrator_steps ()
372+ logger .info ("COMPUTE WORKSSSS" )
373+ #if more propogation steps are requested
374+ self .beginIfBlock ("lambda > prop_lambda_min" )###
375+ self .beginIfBlock ("lambda <= prop_lambda_max" )####
376+
377+ self .beginWhileBlock ("prop < nprop" )#####
378+ self .addComputeGlobal ("prop" , "prop + 1" )
379+
380+ super (AlchemicalNonequilibriumLangevinIntegrator , self )._add_integrator_steps ()
381+ self .endBlock ()#####
382+ self .endBlock ()####
383+ self .endBlock ()###
384+ #ending variables to reset
385+ self .updateRestraints ()
386+ self .addComputeGlobal ("restraint_energy" , self .restraint_energy )
387+ self .addComputeGlobal ("unperturbed_pe" , "energy-restraint_energy" )
388+ self .addComputeGlobal ("step" , "step + 1" )
389+ self .addComputeGlobal ("prop" , "1" )
390+
391+ self .endBlock ()#
392+
0 commit comments