Skip to content

Commit 717ff9f

Browse files
committed
Fix double work accumulation in AlchemicalExternalLangevinIntegrator
- Remove redundant work accumulation from _add_integrator_steps() - Fix lambda_step reset in reset() method to prevent lambda drift - Add debug variables for work tracking in different phases - Remove super call from propagation loop to prevent double counting - Fix freeze_radius usage in example_ncmc.py
1 parent 59d0e1f commit 717ff9f

File tree

2 files changed

+46
-3
lines changed

2 files changed

+46
-3
lines changed

blues/integrators.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,9 @@ def __init__(self,
135135
self.addGlobalVariable("prop", 1)
136136
self.addGlobalVariable("prop_lambda_min", self._prop_lambda[0])
137137
self.addGlobalVariable("prop_lambda_max", self._prop_lambda[1])
138+
self.addGlobalVariable("debug_work", 0.0)
139+
self.addGlobalVariable("work_0_to_05", 0.0)
140+
self.addGlobalVariable("work_05_to_1", 0.0)
138141
# Behavior changed in https://github.com/choderalab/openmmtools/commit/7c2630050631e126d61b67f56e941de429b2d643#diff-5ce4bc8893e544833c827299a5d48b0d
139142
self._step_dispatch_table['H'] = (self._add_alchemical_perturbation_step, False)
140143
#$self._registered_step_types['H'] = (
@@ -192,16 +195,18 @@ def _add_integrator_steps(self):
192195
self.addComputeGlobal("unperturbed_pe", "energy")
193196
self.endBlock()
194197
#initial iteration
195-
self.addComputeGlobal("protocol_work", "protocol_work + (perturbed_pe - unperturbed_pe)")
198+
# Work accumulation is handled in _add_alchemical_perturbation_step()
196199
super(AlchemicalNonequilibriumLangevinIntegrator, self)._add_integrator_steps()
197200
#if more propogation steps are requested
198201
self.beginIfBlock("lambda > prop_lambda_min")
199202
self.beginIfBlock("lambda <= prop_lambda_max")
200203

201204
self.beginWhileBlock("prop < nprop")
202205
self.addComputeGlobal("prop", "prop + 1")
203-
204206
super(AlchemicalNonequilibriumLangevinIntegrator, self)._add_integrator_steps()
207+
# Propagation steps - just do additional integration without lambda changes
208+
# The parent's integrator steps are already called above, so we don't need to call them again
209+
# This prevents double work accumulation
205210
self.endBlock()
206211
self.endBlock()
207212
self.endBlock()
@@ -234,6 +239,20 @@ def _add_alchemical_perturbation_step(self):
234239
# Accumulate protocol work
235240
self.addComputeGlobal("Enew", "energy")
236241
self.addComputeGlobal("protocol_work", "protocol_work + (Enew-Eold)")
242+
243+
# Track work in different phases
244+
self.beginIfBlock("lambda <= 0.5")
245+
self.addComputeGlobal("work_0_to_05", "work_0_to_05 + (Enew-Eold)")
246+
self.endBlock()
247+
248+
self.beginIfBlock("lambda > 0.5")
249+
self.addComputeGlobal("work_05_to_1", "work_05_to_1 + (Enew-Eold)")
250+
self.endBlock()
251+
252+
# Debug: Print work at move step (λ=0.5)
253+
self.beginIfBlock("abs(lambda - 0.5) < 0.001")
254+
self.addComputeGlobal("debug_work", "Enew - Eold")
255+
self.endBlock()
237256
self.endBlock()
238257

239258
def getLogAcceptanceProbability(self, context):
@@ -243,17 +262,33 @@ def getLogAcceptanceProbability(self, context):
243262
logp_accept = -1.0 * (protocol + shadow) * _OPENMM_ENERGY_UNIT / self.kT
244263
logger.info(f'[WORK] protocol work: {protocol}')
245264
logger.info(f'[shadow] shadow: {shadow}')
265+
266+
# Debug: Print work in different phases
267+
try:
268+
work_0_to_05 = self.getGlobalVariableByName("work_0_to_05")
269+
work_05_to_1 = self.getGlobalVariableByName("work_05_to_1")
270+
debug_work = self.getGlobalVariableByName("debug_work")
271+
logger.info(f'[DEBUG] Work 0→0.5: {work_0_to_05}')
272+
logger.info(f'[DEBUG] Work 0.5→1.0: {work_05_to_1}')
273+
logger.info(f'[DEBUG] Work at λ=0.5: {debug_work}')
274+
except:
275+
pass
276+
246277
return logp_accept
247278

248279
def reset(self):
249280
self.setGlobalVariableByName("step", 0)
250281
self.setGlobalVariableByName("lambda", 0.0)
282+
self.setGlobalVariableByName("lambda_step", 0.0)
251283
self.setGlobalVariableByName("protocol_work", 0.0)
252284
self.setGlobalVariableByName("shadow_work", 0.0)
253285
self.setGlobalVariableByName("first_step", 0)
254286
self.setGlobalVariableByName("perturbed_pe", 0.0)
255287
self.setGlobalVariableByName("unperturbed_pe", 0.0)
256288
self.setGlobalVariableByName("prop", 1)
289+
self.setGlobalVariableByName("debug_work", 0.0)
290+
self.setGlobalVariableByName("work_0_to_05", 0.0)
291+
self.setGlobalVariableByName("work_05_to_1", 0.0)
257292
super(AlchemicalExternalLangevinIntegrator, self).reset()
258293

259294

blues/simulation.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1072,7 +1072,14 @@ def _stepNCMC(self, nstepsNC, moveStep, move_engine=None):
10721072
#print("Running move_engine.runEngine() at moveStep")
10731073
self._ncmc_sim.context = move_engine.runEngine(self._ncmc_sim.context)
10741074

1075-
1075+
lambda_val = self._ncmc_sim.context._integrator.getGlobalVariableByName("lambda")
1076+
# if lambda val is near 0.0
1077+
if lambda_val < 0.000100 or step == 0 or abs(lambda_val - 0.5) < 1e-4 or abs(lambda_val - 1.0) < 1e-4:
1078+
state = self._ncmc_sim.context.getState(getPositions=True, getEnergy=True)
1079+
logger.info(f"[λ={lambda_val:.6f} and Step={step}] Total potential energy: {state.getPotentialEnergy()}")
1080+
lambda_s = self._ncmc_sim.context.getParameter("lambda_sterics")
1081+
lambda_e = self._ncmc_sim.context.getParameter("lambda_electrostatics")
1082+
logger.info(f"[Step {step}] λ_sterics = {lambda_s}, λ_electrostatics = {lambda_e}")
10761083
self._ncmc_sim.step(1)
10771084

10781085
if step == lastStep:
@@ -1090,6 +1097,7 @@ def _stepNCMC(self, nstepsNC, moveStep, move_engine=None):
10901097
ncmc_state1 = self.getStateFromContext(self._ncmc_sim.context, self._state_keys)
10911098
self._setStateTable('ncmc', 'state1', ncmc_state1)
10921099

1100+
10931101
# # Optional: check difference
10941102
# import numpy as np
10951103
# delta = np.abs(ncmc_state1['positions'] - ncmc_state0['positions'])

0 commit comments

Comments
 (0)