Skip to content

Commit eb80081

Browse files
author
Steven Ayoub
committed
Log SmartDartMove dart acceptance and iteration timing in MonteCarloSimulation
- Imported SmartDartMove and time for performance logging - Tracked wall-clock duration of each NCMC iteration - Calculated and logged acceptance ratio for SmartDartMove proposals and accepted darts - Included config dump for debug inspection in SimulationFactory
1 parent 0f46361 commit eb80081

File tree

1 file changed

+173
-46
lines changed

1 file changed

+173
-46
lines changed

blues/simulation.py

Lines changed: 173 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323

2424
from blues import utils
2525
from blues.integrators import AlchemicalExternalLangevinIntegrator
26-
26+
from blues.moves import SmartDartMove
27+
import time
2728
finfo = np.finfo(np.float32)
2829
rtol = finfo.precision
2930
logger = logging.getLogger(__name__)
@@ -599,7 +600,7 @@ def __init__(self, systems, move_engine, config=None, md_reporters=None, ncmc_re
599600
if ncmc_reporters:
600601
self._ncmc_reporters = ncmc_reporters
601602
self.ncmc = SimulationFactory.attachReporters(self.ncmc, self._ncmc_reporters)
602-
603+
logger.info(f'System Factory config: {config}')
603604
@classmethod
604605
def addBarostat(cls, system, temperature=300 * unit.kelvin, pressure=1 * unit.atmospheres, frequency=25, **kwargs):
605606
"""
@@ -805,7 +806,7 @@ def generateSimulationSet(self, config=None):
805806

806807
#Initialize the Move Engine with the Alchemical System and NCMC Integrator
807808
for move in self._move_engine.moves:
808-
self._alch_system, self.ncmc_integrator = move.initializeSystem(self._alch_system, self.ncmc_integrator)
809+
self._alch_system, self.ncmc_integrator = move.initializeSystem(self._alch_system, self.ncmc_integrator, self.config)
809810
self.ncmc = self.generateSimFromStruct(self._structure, self._alch_system, self.ncmc_integrator, **config)
810811
utils.print_host_info(self.ncmc)
811812

@@ -1063,19 +1064,66 @@ def _stepNCMC(self, nstepsNC, moveStep, move_engine=None):
10631064
if not step:
10641065
#print("Calling beforeMove()")
10651066
self._ncmc_sim.context = move_engine.selected_move.beforeMove(self._ncmc_sim.context)
1066-
1067+
# NEW: check if move should be skipped
1068+
logger.info(f"MOVE SHOULD BE SKIPPED? : {move_engine.selected_move.skip_ncmc}")
1069+
1070+
if move_engine.selected_move.skip_ncmc:
1071+
logger.info("Skipping NCMC steps because no valid move was proposed.\n")
1072+
logger.info(f"[DEBUG]: the restraint pose is: restraint_pose_{move_engine.selected_move.current_pose}")
1073+
lambda_rest = self._ncmc_sim.context.getParameter("lambda_restraints")
1074+
logger.info(f"[Step {step}] lambda_restraints = {lambda_rest}")
1075+
positions = self._md_sim.context.getState(getPositions=True).getPositions(asNumpy=True)
1076+
move = self._move_engine.selected_move
1077+
ligand_coords = positions[move.atom_indices]
1078+
ligand_com = ligand_coords.mean(axis=0)
1079+
logger.info(f"[DEBUG] Ligand COM before MD (skipped NCMC): {ligand_com}")
1080+
break
10671081
if step == moveStep:
10681082
if hasattr(logger, 'report'):
10691083
logger.info = logger.report
10701084
logger.info('Performing %s...' % move_engine.move_name)
1071-
1072-
#print("Running move_engine.runEngine() at moveStep")
1085+
restraint_pose = self._ncmc_sim.context.getParameter(f"restraint_pose_{move_engine.selected_move.current_pose}")
1086+
lambda_rest = self._ncmc_sim.context.getParameter("lambda_restraints")
1087+
logger.info(f"[Step {step}] lambda_restraints = {lambda_rest}")
1088+
logger.info(f"[Step {step}] restraint_pose_{move_engine.selected_move.current_pose} = {restraint_pose}")
1089+
state = self._ncmc_sim.context.getState(getEnergy=True, groups={move_engine.selected_move.restraint_group})
1090+
energy = state.getPotentialEnergy().value_in_unit(unit.kilojoules_per_mole)
1091+
logger.info(f"[Step {step}] Restraint Energy (group {move_engine.selected_move.restraint_group}): {energy:.4f} kJ/mol")
1092+
1093+
try:
1094+
steric_state = self._ncmc_sim.context.getState(getEnergy=True, groups={move_engine.selected_move.steric_group})
1095+
steric_energy = steric_state.getPotentialEnergy().value_in_unit(unit.kilojoules_per_mole)
1096+
logger.info(f"[Step {step}]Steric Energy During Move Proposal: {steric_energy:.4f} kJ/mol")
1097+
except Exception as e:
1098+
logger.warning(f"[Step {step}] Could not retrieve steric energy: {e}")
1099+
# Perform the NCMC move (lambda 0 → 0.5 and apply move)
10731100
self._ncmc_sim.context = move_engine.runEngine(self._ncmc_sim.context)
1074-
10751101

1076-
self._ncmc_sim.step(1)
1102+
if move_engine.move.selected_move.skip_ncmc == None:
1103+
logger.info("No valid dart region found. Skipping reverse NCMC and rejecting move.")
1104+
# Call afterMove to clean up / reset lambda
1105+
self._ncmc_sim.context = move_engine.selected_move.afterMove(self._ncmc_sim.context)
1106+
# skip remainder of NCMC (e.g., 0.5 → 1.0)
1107+
break
1108+
10771109

1110+
self._ncmc_sim.step(1)
1111+
# if step % 50 == 0:
1112+
# lambda_val = self._ncmc_sim.context._integrator.getGlobalVariableByName("lambda")
1113+
# logger.info(f"NCMC step {step}: lambda = {lambda_val}")
10781114
if step == lastStep:
1115+
logger.info("AFTER MOVE WILL BE CALLED")
1116+
lambda_val = self._ncmc_sim.context._integrator.getGlobalVariableByName("lambda")
1117+
logger.info(f"NCMC step {step}: lambda = {lambda_val}")
1118+
1119+
# Log sterics after move
1120+
try:
1121+
steric_state = self._ncmc_sim.context.getState(getEnergy=True, groups={move_engine.selected_move.steric_group})
1122+
steric_energy = steric_state.getPotentialEnergy().value_in_unit(unit.kilojoules_per_mole)
1123+
logger.info(f"[Step {step}] Steric Energy AFTER MOVE (group {move_engine.selected_move.steric_group}): {steric_energy:.4f} kJ/mol")
1124+
except Exception as e:
1125+
logger.warning(f"[Step {step}] Could not retrieve steric energy after move: {e}")
1126+
10791127
self._ncmc_sim.context = move_engine.selected_move.afterMove(self._ncmc_sim.context)
10801128
# Debug: print positions after afterMove
10811129

@@ -1090,10 +1138,7 @@ def _stepNCMC(self, nstepsNC, moveStep, move_engine=None):
10901138
ncmc_state1 = self.getStateFromContext(self._ncmc_sim.context, self._state_keys)
10911139
self._setStateTable('ncmc', 'state1', ncmc_state1)
10921140

1093-
# # Optional: check difference
1094-
# import numpy as np
1095-
# delta = np.abs(ncmc_state1['positions'] - ncmc_state0['positions'])
1096-
# print("Max delta between state0 and state1:", np.max(delta))
1141+
10971142

10981143
def _computeAlchemicalCorrection(self):
10991144
"""Computes the alchemical correction term from switching between the NCMC
@@ -1125,44 +1170,94 @@ def _acceptRejectMove(self, write_move=False):
11251170
write_move : bool, default=False
11261171
If True, writes the proposed NCMC move to a PDB file.
11271172
"""
1128-
work_ncmc = self._ncmc_sim.context._integrator.getLogAcceptanceProbability(self._ncmc_sim.context)
1129-
randnum = math.log(np.random.random())
1173+
move = self._move_engine.selected_move
1174+
acceptance_ratio = move.acceptance_ratio
11301175

1131-
# Compute correction if work_ncmc is not NaN
1132-
if not np.isnan(work_ncmc):
1133-
correction_factor = self._computeAlchemicalCorrection()
1134-
logger.debug(
1135-
'NCMCLogAcceptanceProbability = %.6f + Alchemical Correction = %.6f' % (work_ncmc, correction_factor))
1136-
work_ncmc = work_ncmc + correction_factor
1176+
# Case 1: NCMC move was skipped entirely
1177+
if acceptance_ratio is None:
1178+
self.reject += 1
1179+
logger.info("NCMC MOVE REJECTED: No valid dart region, skipped move.")
1180+
self._validate_potential_energy_consistency(
1181+
self.stateTable['md']['state0'], self._md_sim.context
1182+
)
1183+
return
11371184

1185+
# Get raw log acceptance probability
1186+
work_ncmc = self._ncmc_sim.context._integrator.getLogAcceptanceProbability(self._ncmc_sim.context)
1187+
randnum = math.log(np.random.random())
1188+
logger.info(f'Protocol of work: {work_ncmc}')
1189+
if np.isnan(work_ncmc):
1190+
self.reject += 1
1191+
logger.warning("NCMC MOVE REJECTED: work_ncmc is NaN.")
1192+
self._validate_potential_energy_consistency(
1193+
self.stateTable['md']['state0'], self._md_sim.context
1194+
)
1195+
return
1196+
1197+
# Apply alchemical correction
1198+
correction_factor = self._computeAlchemicalCorrection()
1199+
logger.info(f"Correcion Factor: {correction_factor}")
1200+
work_ncmc += correction_factor
1201+
1202+
# Optional: apply statistical restraint correction
1203+
# if acceptance_ratio != 1.0:
1204+
# work_ncmc += math.log(acceptance_ratio)
1205+
1206+
# Metropolis acceptance criterion
11381207
if work_ncmc > randnum:
11391208
self.accept += 1
1140-
logger.info('NCMC MOVE ACCEPTED: work_ncmc {} > randnum {}'.format(work_ncmc, randnum))
1209+
logger.info(f"NCMC MOVE ACCEPTED: work_ncmc {work_ncmc:.4f} > randnum {randnum:.4f}")
1210+
for move in self._move_engine.moves:
1211+
if isinstance(move, SmartDartMove):
1212+
move.num_accepted_darts += 1
11411213

1142-
# If accept move, sync NCMC state to MD context
1214+
# Sync NCMC state to MD context
11431215
ncmc_state1 = self.stateTable['ncmc']['state1']
1144-
self._md_sim.context = self.setContextFromState(self._md_sim.context, ncmc_state1, velocities=False)
1216+
self._md_sim.context = self.setContextFromState(
1217+
self._md_sim.context, ncmc_state1, velocities=False
1218+
)
11451219

11461220
if write_move:
1147-
utils.saveSimulationFrame(self._md_sim, '{}acc-it{}.pdb'.format(self._config['outfname'],
1148-
self.currentIter))
1149-
1221+
utils.saveSimulationFrame(
1222+
self._md_sim, f"{self._config['outfname']}acc-it{self.currentIter}.pdb"
1223+
)
11501224
else:
11511225
self.reject += 1
1152-
logger.info('NCMC MOVE REJECTED: work_ncmc {} < {}'.format(work_ncmc, randnum))
1153-
1154-
# If reject move, do nothing,
1155-
# NCMC simulation be updated from MD Simulation next iteration.
1156-
1157-
# Potential energy should be from last MD step in the previous iteration
1158-
md_state0 = self.stateTable['md']['state0']
1159-
md_PE = self._md_sim.context.getState(getEnergy=True).getPotentialEnergy()
1160-
if not math.isclose(md_state0['potential_energy']._value, md_PE._value, rel_tol=float('1e-%s' % rtol)):
1161-
logger.error(
1162-
'Last MD potential energy %s != Current MD potential energy %s. Potential energy should match the prior state.'
1163-
% (md_state0['potential_energy'], md_PE))
1164-
sys.exit(1)
1226+
logger.info(f"NCMC MOVE REJECTED: work_ncmc {work_ncmc:.4f} < randnum {randnum:.4f}")
1227+
self._validate_potential_energy_consistency(
1228+
self.stateTable['md']['state0'], self._md_sim.context
1229+
)
1230+
1231+
def _validate_potential_energy_consistency(self, md_state0, md_context, rtol=5):
1232+
"""
1233+
Validates that the MD potential energy has not drifted after a rejected NCMC move.
11651234
1235+
Parameters
1236+
----------
1237+
md_state0 : dict
1238+
Stored MD state (before the NCMC move) from self.stateTable['md']['state0'].
1239+
md_context : openmm.Context
1240+
The current OpenMM context for the MD simulation.
1241+
rtol : int, optional
1242+
Relative tolerance for comparing potential energies (default is 1e-8).
1243+
1244+
Raises
1245+
------
1246+
RuntimeError
1247+
If the current MD potential energy does not match the stored value within tolerance.
1248+
"""
1249+
md_PE_current = md_context.getState(getEnergy=True).getPotentialEnergy()
1250+
md_PE_stored = md_state0['potential_energy']
1251+
1252+
if not math.isclose(md_PE_stored._value, md_PE_current._value, rel_tol=10**(-rtol)):
1253+
logger.error(
1254+
f"Potential energy mismatch after rejected move:\n"
1255+
f"Stored MD PE: {md_PE_stored}\n"
1256+
f"Current MD PE: {md_PE_current}\n"
1257+
f"Relative diff: {abs(md_PE_stored._value - md_PE_current._value)}"
1258+
)
1259+
raise RuntimeError("Potential energy mismatch: MD state was not preserved correctly after rejection.")
1260+
11661261
def _resetSimulations(self, temperature=None):
11671262
"""At the end of each iteration:
11681263
@@ -1194,6 +1289,16 @@ def _stepMD(self, nstepsMD):
11941289
"""
11951290
logger.info('Advancing %i MD steps...' % (nstepsMD))
11961291
self._md_sim.currentIter = self.currentIter
1292+
1293+
move = self._move_engine.selected_move
1294+
if move.skip_ncmc:
1295+
positions = self._md_sim.context.getState(getPositions=True).getPositions(asNumpy=True)
1296+
ligand_com_pre = positions[move.atom_indices].mean(axis=0)
1297+
logger.info(f"[DEBUG] Ligand COM BEFORE MD (iter {self.currentIter}): {ligand_com_pre}")
1298+
lambda_rest = self._ncmc_sim.context.getParameter("lambda_restraints")
1299+
logger.info(f"[SKIP-NCMC]: restraints should be off")
1300+
logger.info(f"[DEBUG] lambda_restraints before MD (iter {self.currentIter}): {lambda_rest}")
1301+
11971302
# Retrieve MD state before proposed move
11981303
# Helps determine if previous iteration placed ligand poorly
11991304
md_state0 = self.stateTable['md']['state0']
@@ -1210,6 +1315,14 @@ def _stepMD(self, nstepsMD):
12101315
'MD-fail-it%s-md%i.pdb' % (self.currentIter, self._md_sim.currentStep))
12111316
sys.exit(1)
12121317

1318+
# Log ligand COM after MD
1319+
if move.skip_ncmc:
1320+
positions = self._md_sim.context.getState(getPositions=True).getPositions(asNumpy=True)
1321+
ligand_com_post = positions[move.atom_indices].mean(axis=0)
1322+
logger.info(f"[DEBUG] Ligand COM AFTER MD (iter {self.currentIter}): {ligand_com_post}")
1323+
lambda_rest = self._ncmc_sim.context.getParameter("lambda_restraints")
1324+
logger.info(f"[DEBUG] lambda_restraints AFTER MD (iter {self.currentIter}): {lambda_rest}")
1325+
12131326
def run(self, nIter=0, nstepsNC=0, moveStep=0, nstepsMD=0, temperature=300, write_move=False, **config):
12141327
"""Executes the BLUES engine to iterate over the actions:
12151328
Perform NCMC simulation, perform proposed move, accepts/rejects move,
@@ -1244,21 +1357,19 @@ def run(self, nIter=0, nstepsNC=0, moveStep=0, nstepsMD=0, temperature=300, writ
12441357
self.currentIter = N
12451358
logger.info('BLUES Iteration: %s' % N)
12461359
self._syncStatesMDtoNCMC()
1247-
#print("✅ _syncStatesMDtoNCMC")
12481360
self._stepNCMC(nstepsNC, moveStep)
1249-
#print("✅ _stepNCMC")
12501361
self._acceptRejectMove(write_move)
1251-
#print("✅ _acceptRejectMove")
1252-
#print(f'what is temperature: {temperature}')
12531362
self._resetSimulations(temperature)
1254-
#print("✅ _resetSimulations")
12551363
self._stepMD(nstepsMD)
1256-
#print("✅ _stepMD")
1257-
#print(f'NITER: {N}/{nIter}')
12581364
# END OF NITER
12591365
self.acceptRatio = self.accept / float(nIter)
12601366
logger.info('Acceptance Ratio: %s' % self.acceptRatio)
12611367
logger.info('nIter: %s ' % nIter)
1368+
for move in self._move_engine.moves:
1369+
if isinstance(move, SmartDartMove):
1370+
acceptance_ratio = move.num_accepted_darts / float(move.num_proposed_darts)
1371+
logger.info(f"[SmartDartMove] Accepted {move.num_accepted_darts} / {move.num_proposed_darts} darts")
1372+
logger.info(f"[SmartDartMove] Dart acceptance ratio: {acceptance_ratio:.3f}")
12621373

12631374
class MonteCarloSimulation(BLUESSimulation):
12641375
"""Simulation class provides the functions that perform the MonteCarlo run.
@@ -1298,6 +1409,9 @@ def _acceptRejectMove(self, temperature=None):
12981409
self.accept += 1
12991410
logger.info('MC MOVE ACCEPTED: work_mc {} > randnum {}'.format(work_mc, randnum))
13001411
self._md_sim.context.setPositions(md_state1['positions'])
1412+
for move in self._move_engine.moves:
1413+
if isinstance(move, SmartDartMove):
1414+
move.num_accepted_darts += 1
13011415
else:
13021416
self.reject += 1
13031417
logger.info('MC MOVE REJECTED: work_mc {} < {}'.format(work_mc, randnum))
@@ -1329,10 +1443,23 @@ def run(self, nIter=0, mc_per_iter=0, nstepsMD=0, temperature=300, write_move=Fa
13291443

13301444
self._syncStatesMDtoNCMC()
13311445
for N in range(nIter):
1446+
iter_start = time.time()
13321447
self.currentIter = N
13331448
logger.info('MonteCarlo Iteration: %s' % N)
13341449
for i in range(mc_per_iter):
13351450
self._syncStatesMDtoNCMC()
13361451
self._stepMC_()
13371452
self._acceptRejectMove(temperature)
13381453
self._stepMD(nstepsMD)
1454+
iter_end = time.time()
1455+
logger.info(f"Iteration {N} took {iter_end - iter_start:.2f} sec")
1456+
1457+
self.acceptRatio = self.accept / float(nIter)
1458+
logger.info('Acceptance Ratio: %s' % self.acceptRatio)
1459+
logger.info('nIter: %s ' % nIter)
1460+
# After move, access found_dart
1461+
for move in self._move_engine.moves:
1462+
if isinstance(move, SmartDartMove):
1463+
acceptance_ratio = move.num_accepted_darts / float(move.num_proposed_darts)
1464+
logger.info(f"[SmartDartMove] Accepted {move.num_accepted_darts} / {move.num_proposed_darts} darts")
1465+
logger.info(f"[SmartDartMove] Dart acceptance ratio: {acceptance_ratio:.3f}")

0 commit comments

Comments
 (0)