@@ -135,6 +135,9 @@ def __init__(self,
135135 self .addGlobalVariable ("prop" , 1 )
136136 self .addGlobalVariable ("prop_lambda_min" , self ._prop_lambda [0 ])
137137 self .addGlobalVariable ("prop_lambda_max" , self ._prop_lambda [1 ])
138+ self .addGlobalVariable ("debug_work" , 0.0 )
139+ self .addGlobalVariable ("work_0_to_05" , 0.0 )
140+ self .addGlobalVariable ("work_05_to_1" , 0.0 )
138141 # Behavior changed in https://github.com/choderalab/openmmtools/commit/7c2630050631e126d61b67f56e941de429b2d643#diff-5ce4bc8893e544833c827299a5d48b0d
139142 self ._step_dispatch_table ['H' ] = (self ._add_alchemical_perturbation_step , False )
140143 #$self._registered_step_types['H'] = (
@@ -192,16 +195,18 @@ def _add_integrator_steps(self):
192195 self .addComputeGlobal ("unperturbed_pe" , "energy" )
193196 self .endBlock ()
194197 #initial iteration
195- self . addComputeGlobal ( "protocol_work" , "protocol_work + (perturbed_pe - unperturbed_pe)" )
198+ # Work accumulation is handled in _add_alchemical_perturbation_step( )
196199 super (AlchemicalNonequilibriumLangevinIntegrator , self )._add_integrator_steps ()
197200 #if more propogation steps are requested
198201 self .beginIfBlock ("lambda > prop_lambda_min" )
199202 self .beginIfBlock ("lambda <= prop_lambda_max" )
200203
201204 self .beginWhileBlock ("prop < nprop" )
202205 self .addComputeGlobal ("prop" , "prop + 1" )
203-
204206 super (AlchemicalNonequilibriumLangevinIntegrator , self )._add_integrator_steps ()
207+ # Propagation steps - just do additional integration without lambda changes
208+ # The parent's integrator steps are already called above, so we don't need to call them again
209+ # This prevents double work accumulation
205210 self .endBlock ()
206211 self .endBlock ()
207212 self .endBlock ()
@@ -234,6 +239,20 @@ def _add_alchemical_perturbation_step(self):
234239 # Accumulate protocol work
235240 self .addComputeGlobal ("Enew" , "energy" )
236241 self .addComputeGlobal ("protocol_work" , "protocol_work + (Enew-Eold)" )
242+
243+ # Track work in different phases
244+ self .beginIfBlock ("lambda <= 0.5" )
245+ self .addComputeGlobal ("work_0_to_05" , "work_0_to_05 + (Enew-Eold)" )
246+ self .endBlock ()
247+
248+ self .beginIfBlock ("lambda > 0.5" )
249+ self .addComputeGlobal ("work_05_to_1" , "work_05_to_1 + (Enew-Eold)" )
250+ self .endBlock ()
251+
252+ # Debug: Print work at move step (λ=0.5)
253+ self .beginIfBlock ("abs(lambda - 0.5) < 0.001" )
254+ self .addComputeGlobal ("debug_work" , "Enew - Eold" )
255+ self .endBlock ()
237256 self .endBlock ()
238257
239258 def getLogAcceptanceProbability (self , context ):
@@ -243,17 +262,33 @@ def getLogAcceptanceProbability(self, context):
243262 logp_accept = - 1.0 * (protocol + shadow ) * _OPENMM_ENERGY_UNIT / self .kT
244263 logger .info (f'[WORK] protocol work: { protocol } ' )
245264 logger .info (f'[shadow] shadow: { shadow } ' )
265+
266+ # Debug: Print work in different phases
267+ try :
268+ work_0_to_05 = self .getGlobalVariableByName ("work_0_to_05" )
269+ work_05_to_1 = self .getGlobalVariableByName ("work_05_to_1" )
270+ debug_work = self .getGlobalVariableByName ("debug_work" )
271+ logger .info (f'[DEBUG] Work 0→0.5: { work_0_to_05 } ' )
272+ logger .info (f'[DEBUG] Work 0.5→1.0: { work_05_to_1 } ' )
273+ logger .info (f'[DEBUG] Work at λ=0.5: { debug_work } ' )
274+ except :
275+ pass
276+
246277 return logp_accept
247278
248279 def reset (self ):
249280 self .setGlobalVariableByName ("step" , 0 )
250281 self .setGlobalVariableByName ("lambda" , 0.0 )
282+ self .setGlobalVariableByName ("lambda_step" , 0.0 )
251283 self .setGlobalVariableByName ("protocol_work" , 0.0 )
252284 self .setGlobalVariableByName ("shadow_work" , 0.0 )
253285 self .setGlobalVariableByName ("first_step" , 0 )
254286 self .setGlobalVariableByName ("perturbed_pe" , 0.0 )
255287 self .setGlobalVariableByName ("unperturbed_pe" , 0.0 )
256288 self .setGlobalVariableByName ("prop" , 1 )
289+ self .setGlobalVariableByName ("debug_work" , 0.0 )
290+ self .setGlobalVariableByName ("work_0_to_05" , 0.0 )
291+ self .setGlobalVariableByName ("work_05_to_1" , 0.0 )
257292 super (AlchemicalExternalLangevinIntegrator , self ).reset ()
258293
259294
0 commit comments