11import openmm
22from openmmtools .integrators import AlchemicalNonequilibriumLangevinIntegrator
33import logging
4-
4+ from openmm import unit
55logger = logging .getLogger (__name__ )
66# Energy unit used by OpenMM unit system
77_OPENMM_ENERGY_UNIT = openmm .unit .kilojoules_per_mole
@@ -99,7 +99,7 @@ class AlchemicalExternalLangevinIntegrator(AlchemicalNonequilibriumLangevinInteg
9999
100100 def __init__ (self ,
101101 alchemical_functions ,
102- splitting = "R V O H O V R" ,
102+ splitting = "H V R O V R H " ,
103103 temperature = 298.0 * openmm .unit .kelvin ,
104104 collision_rate = 1.0 / openmm .unit .picoseconds ,
105105 timestep = 1.0 * openmm .unit .femtoseconds ,
@@ -135,17 +135,21 @@ def __init__(self,
135135 self .addGlobalVariable ("prop" , 1 )
136136 self .addGlobalVariable ("prop_lambda_min" , self ._prop_lambda [0 ])
137137 self .addGlobalVariable ("prop_lambda_max" , self ._prop_lambda [1 ])
138+ self .addGlobalVariable ("debug_work" , 0.0 )
139+ self .addGlobalVariable ("work_0_to_05" , 0.0 )
140+ self .addGlobalVariable ("work_05_to_1" , 0.0 )
138141 # Behavior changed in https://github.com/choderalab/openmmtools/commit/7c2630050631e126d61b67f56e941de429b2d643#diff-5ce4bc8893e544833c827299a5d48b0d
139142 self ._step_dispatch_table ['H' ] = (self ._add_alchemical_perturbation_step , False )
140143 #$self._registered_step_types['H'] = (
141144 # self._add_alchemical_perturbation_step, False)
142145 self .addGlobalVariable ("debug" , 0 )
143-
146+ logger . info ( f'splitting: { splitting } ' )
144147 try :
145148 self .getGlobalVariableByName ("shadow_work" )
146149 except :
147150 self .addGlobalVariable ('shadow_work' , 0 )
148151
152+
149153 def _get_prop_lambda (self , prop_lambda ):
150154 prop_lambda_max = round (prop_lambda + 0.5 , 4 )
151155 prop_lambda_min = round (0.5 - prop_lambda , 4 )
@@ -157,8 +161,63 @@ def _get_prop_lambda(self, prop_lambda):
157161 prop_lambda_max = - 1.0
158162
159163 return prop_lambda_min , prop_lambda_max
164+
165+
166+ def _add_integrator_steps (self ):
167+ """
168+ Override the base class to insert reset steps around the integrator.
169+ """
170+
171+ # First step: Constrain positions and velocities and reset work accumulators and alchemical integrators
172+ self .beginIfBlock ('step = 0' )
173+ self .addComputeGlobal ("perturbed_pe" , "energy" )
174+ self .addComputeGlobal ("unperturbed_pe" , "energy" )
175+ self .addConstrainPositions ()
176+ self .addConstrainVelocities ()
177+ self ._add_reset_protocol_work_step ()
178+ self ._add_alchemical_reset_step ()
179+ self .endBlock ()
180+
181+ # Main body
182+ if self ._n_steps_neq == 0 :
183+ # If nsteps = 0, we need to force execution on the first step only.
184+ self .beginIfBlock ('step = 0' )
185+ super (AlchemicalNonequilibriumLangevinIntegrator , self )._add_integrator_steps ()
186+ self .addComputeGlobal ("step" , "step + 1" )
187+ self .endBlock ()
188+ else :
189+ #call the superclass function to insert the appropriate steps, provided the step number is less than n_steps
190+ self .beginIfBlock ("step < n_lambda_steps" )
191+ self .addComputeGlobal ("perturbed_pe" , "energy" )
192+ self .beginIfBlock ("first_step < 1" )
193+ #TODO write better test that checks that the initial work isn't gigantic
194+ self .addComputeGlobal ("first_step" , "1" )
195+ self .addComputeGlobal ("unperturbed_pe" , "energy" )
196+ self .endBlock ()
197+ #initial iteration
198+ # Work accumulation is handled in _add_alchemical_perturbation_step()
199+ super (AlchemicalNonequilibriumLangevinIntegrator , self )._add_integrator_steps ()
200+ #if more propogation steps are requested
201+ self .beginIfBlock ("lambda > prop_lambda_min" )
202+ self .beginIfBlock ("lambda <= prop_lambda_max" )
203+
204+ self .beginWhileBlock ("prop < nprop" )
205+ self .addComputeGlobal ("prop" , "prop + 1" )
206+ super (AlchemicalNonequilibriumLangevinIntegrator , self )._add_integrator_steps ()
207+ # Propagation steps - just do additional integration without lambda changes
208+ # The parent's integrator steps are already called above, so we don't need to call them again
209+ # This prevents double work accumulation
210+ self .endBlock ()
211+ self .endBlock ()
212+ self .endBlock ()
213+ #ending variables to reset
214+ self .addComputeGlobal ("unperturbed_pe" , "energy" )
215+ self .addComputeGlobal ("step" , "step + 1" )
216+ self .addComputeGlobal ("prop" , "1" )
217+
218+ self .endBlock ()
219+
160220
161-
162221
163222 def _add_alchemical_perturbation_step (self ):
164223 """
@@ -180,30 +239,190 @@ def _add_alchemical_perturbation_step(self):
180239 # Accumulate protocol work
181240 self .addComputeGlobal ("Enew" , "energy" )
182241 self .addComputeGlobal ("protocol_work" , "protocol_work + (Enew-Eold)" )
242+
243+ # Track work in different phases
244+ self .beginIfBlock ("lambda <= 0.5" )
245+ self .addComputeGlobal ("work_0_to_05" , "work_0_to_05 + (Enew-Eold)" )
246+ self .endBlock ()
247+
248+ self .beginIfBlock ("lambda > 0.5" )
249+ self .addComputeGlobal ("work_05_to_1" , "work_05_to_1 + (Enew-Eold)" )
250+ self .endBlock ()
251+
252+ # Debug: Print work at move step (λ=0.5)
253+ self .beginIfBlock ("abs(lambda - 0.5) < 0.001" )
254+ self .addComputeGlobal ("debug_work" , "Enew - Eold" )
255+ self .endBlock ()
183256 self .endBlock ()
184257
185258 def getLogAcceptanceProbability (self , context ):
186259 #TODO remove context from arguments if/once ncmc_switching is changed
187260 protocol = self .getGlobalVariableByName ("protocol_work" )
188261 shadow = self .getGlobalVariableByName ("shadow_work" )
189262 logp_accept = - 1.0 * (protocol + shadow ) * _OPENMM_ENERGY_UNIT / self .kT
263+ logger .info (f'[WORK] protocol work: { protocol } ' )
264+ logger .info (f'[shadow] shadow: { shadow } ' )
190265
191- # sa
192- import numpy as np
193- logger .info (f"Protocol work: { protocol } " )
194- logger .info (f"Shadow work: { shadow } " )
195- logger .info (f"log_accept_prob: { logp_accept } " )
196- logger .info (f"acceptance_prob: { np .exp (logp_accept )} " )
266+ # Debug: Print work in different phases
267+ try :
268+ work_0_to_05 = self .getGlobalVariableByName ("work_0_to_05" )
269+ work_05_to_1 = self .getGlobalVariableByName ("work_05_to_1" )
270+ debug_work = self .getGlobalVariableByName ("debug_work" )
271+ logger .info (f'[DEBUG] Work 0→0.5: { work_0_to_05 } ' )
272+ logger .info (f'[DEBUG] Work 0.5→1.0: { work_05_to_1 } ' )
273+ logger .info (f'[DEBUG] Work at λ=0.5: { debug_work } ' )
274+ except :
275+ pass
197276
198277 return logp_accept
199278
200279 def reset (self ):
201280 self .setGlobalVariableByName ("step" , 0 )
202281 self .setGlobalVariableByName ("lambda" , 0.0 )
282+ self .setGlobalVariableByName ("lambda_step" , 0.0 )
203283 self .setGlobalVariableByName ("protocol_work" , 0.0 )
204284 self .setGlobalVariableByName ("shadow_work" , 0.0 )
205285 self .setGlobalVariableByName ("first_step" , 0 )
206286 self .setGlobalVariableByName ("perturbed_pe" , 0.0 )
207287 self .setGlobalVariableByName ("unperturbed_pe" , 0.0 )
208288 self .setGlobalVariableByName ("prop" , 1 )
289+ self .setGlobalVariableByName ("debug_work" , 0.0 )
290+ self .setGlobalVariableByName ("work_0_to_05" , 0.0 )
291+ self .setGlobalVariableByName ("work_05_to_1" , 0.0 )
209292 super (AlchemicalExternalLangevinIntegrator , self ).reset ()
293+
294+
295+ #TODO: Add a class for the restrained integrator
296+ # Still need to test the restrained integrator
297+ class AlchemicalExternalRestrainedLangevinIntegrator (AlchemicalExternalLangevinIntegrator ):
298+ def __init__ (self ,
299+ alchemical_functions ,
300+ restraint_group ,
301+ splitting = "R V O H O V R" ,
302+ temperature = 298.0 * unit .kelvin ,
303+ collision_rate = 1.0 / unit .picoseconds ,
304+ timestep = 1.0 * unit .femtoseconds ,
305+ constraint_tolerance = 1e-8 ,
306+ measure_shadow_work = False ,
307+ measure_heat = True ,
308+ nsteps_neq = 0 ,
309+ nprop = 1 ,
310+ prop_lambda = 0.3 ,
311+ lambda_restraints = 'max(0, 1-(1/0.10)*abs(lambda-0.5))' ,
312+ #relax_steps=500, #'max(0, 1-(1/0.10)*abs(lambda-0.5))', #"3*lambda^2 - 2*lambda^3", # old: 'max(0, 1-(1/0.10)*abs(lambda-0.5))'
313+ relax_steps = 50 ,
314+ * args , ** kwargs ):
315+
316+ self .lambda_restraints = lambda_restraints
317+ self .restraint_energy = "energy" + str (restraint_group )
318+
319+ super (AlchemicalExternalRestrainedLangevinIntegrator , self ).__init__ (
320+ alchemical_functions ,
321+ splitting ,
322+ temperature ,
323+ collision_rate ,
324+ timestep ,
325+ constraint_tolerance ,
326+ measure_shadow_work ,
327+ measure_heat ,
328+ nsteps_neq ,
329+ nprop ,
330+ prop_lambda ,
331+ * args , ** kwargs )
332+
333+ try :
334+ self .addGlobalVariable ("restraint_energy" , 0 )
335+ except :
336+ pass
337+ logger .info (f'[LAMBDA STEPS] N_lambda_steps: { self ._n_lambda_steps } ' )
338+ # Only declare NEW variables
339+ self .addGlobalVariable ("debug_lambda" , 0.0 )
340+ #self.addGlobalVariable("restraint_energy", 0.0)
341+
342+ # Set existing globals from parent
343+ # self.setGlobalVariableByName("lambda_step", 0.0)
344+ # self.setGlobalVariableByName("lambda", 0.0)
345+
346+ # Optional debug
347+ self .addComputeGlobal ("debug_lambda" , "lambda" )
348+
349+ # Now safe to use lambda_restraints in update
350+ #self.updateRestraints()
351+
352+ logger .info (f"Current nsteps_neq: { nsteps_neq } " )
353+ logger .info (f'lambda_restraints selected: { self .lambda_restraints } ' )
354+ # compute the mid‐point slice index once
355+ mid = int (self ._n_lambda_steps / 2 )
356+ self .addGlobalVariable ("mid_step" , float (mid ))
357+ self .addGlobalVariable ("relax_counter" , 0.0 )
358+ self .addGlobalVariable ("relax_steps" , float (relax_steps ))
359+
360+ def updateRestraints (self ):
361+ logger .info (f"UPDATE RESTAINTS: { self .lambda_restraints } " )
362+ self .addComputeGlobal ('lambda_restraints' , self .lambda_restraints )
363+
364+
365+ def _add_integrator_steps (self ):
366+ """
367+ Override the base class to insert reset steps around the integrator.
368+ """
369+
370+ # First step: Constrain positions and velocities and reset work accumulators and alchemical integrators
371+ logger .info ("sams protocol" )
372+ self .beginIfBlock ('step = 0' )
373+ self .addComputeGlobal ("restraint_energy" , self .restraint_energy )
374+ self .addComputeGlobal ("perturbed_pe" , "energy - restraint_energy" )
375+ self .addComputeGlobal ("unperturbed_pe" , "energy - restraint_energy" )
376+ self .addConstrainPositions ()
377+ self .addConstrainVelocities ()
378+ self ._add_reset_protocol_work_step ()
379+ self ._add_alchemical_reset_step ()
380+ self .endBlock ()
381+
382+ # Main body
383+
384+ if self ._n_steps_neq == 0 :
385+ # If nsteps = 0, we need to force execution on the first step only.
386+ self .beginIfBlock ('step = 0' )
387+ super (AlchemicalNonequilibriumLangevinIntegrator , self )._add_integrator_steps ()
388+ self .addComputeGlobal ("step" , "step + 1" )
389+ self .endBlock ()
390+ else :
391+ #call the superclass function to insert the appropriate steps, provided the step number is less than n_steps
392+ self .beginIfBlock ("step < n_lambda_steps" )#
393+ self .addComputeGlobal ("restraint_energy" , self .restraint_energy )
394+ self .addComputeGlobal ("perturbed_pe" , "energy - restraint_energy" )
395+ self .beginIfBlock ("first_step < 1" )##
396+ #TODO write better test that checks that the initial work isn't gigantic
397+ self .addComputeGlobal ("first_step" , "1" )
398+ self .addComputeGlobal ("restraint_energy" , self .restraint_energy )
399+ self .addComputeGlobal ("unperturbed_pe" , "energy-restraint_energy" )
400+ self .endBlock ()##
401+ #initial iteration
402+ # Gill put this first:
403+ self .addComputeGlobal ("protocol_work" ,
404+ "protocol_work + (perturbed_pe - unperturbed_pe)"
405+ )
406+ super (AlchemicalNonequilibriumLangevinIntegrator , self )._add_integrator_steps ()
407+ logger .info ("COMPUTE WORKSSSS" )
408+ #if more propogation steps are requested
409+ self .beginIfBlock ("lambda > prop_lambda_min" )###
410+ self .beginIfBlock ("lambda <= prop_lambda_max" )####
411+
412+ self .beginWhileBlock ("prop < nprop" )#####
413+ self .addComputeGlobal ("prop" , "prop + 1" )
414+
415+ super (AlchemicalNonequilibriumLangevinIntegrator , self )._add_integrator_steps ()
416+ self .endBlock ()#####
417+ self .endBlock ()####
418+ self .endBlock ()###
419+ #ending variables to reset
420+ self .updateRestraints ()
421+ self .addComputeGlobal ("restraint_energy" , self .restraint_energy )
422+ self .addComputeGlobal ("unperturbed_pe" , "energy-restraint_energy" )
423+ self .addComputeGlobal ("step" , "step + 1" )
424+ self .addComputeGlobal ("prop" , "1" )
425+
426+ self .endBlock ()#
427+
428+
0 commit comments