2323
2424from blues import utils
2525from blues .integrators import AlchemicalExternalLangevinIntegrator
26-
26+ from blues .moves import SmartDartMove
27+ import time
2728finfo = np .finfo (np .float32 )
2829rtol = finfo .precision
2930logger = logging .getLogger (__name__ )
@@ -599,7 +600,7 @@ def __init__(self, systems, move_engine, config=None, md_reporters=None, ncmc_re
599600 if ncmc_reporters :
600601 self ._ncmc_reporters = ncmc_reporters
601602 self .ncmc = SimulationFactory .attachReporters (self .ncmc , self ._ncmc_reporters )
602-
603+ logger . info ( f'System Factory config: { config } ' )
603604 @classmethod
604605 def addBarostat (cls , system , temperature = 300 * unit .kelvin , pressure = 1 * unit .atmospheres , frequency = 25 , ** kwargs ):
605606 """
@@ -805,7 +806,7 @@ def generateSimulationSet(self, config=None):
805806
806807 #Initialize the Move Engine with the Alchemical System and NCMC Integrator
807808 for move in self ._move_engine .moves :
808- self ._alch_system , self .ncmc_integrator = move .initializeSystem (self ._alch_system , self .ncmc_integrator )
809+ self ._alch_system , self .ncmc_integrator = move .initializeSystem (self ._alch_system , self .ncmc_integrator , self . config )
809810 self .ncmc = self .generateSimFromStruct (self ._structure , self ._alch_system , self .ncmc_integrator , ** config )
810811 utils .print_host_info (self .ncmc )
811812
@@ -1063,19 +1064,66 @@ def _stepNCMC(self, nstepsNC, moveStep, move_engine=None):
10631064 if not step :
10641065 #print("Calling beforeMove()")
10651066 self ._ncmc_sim .context = move_engine .selected_move .beforeMove (self ._ncmc_sim .context )
1066-
1067+ # NEW: check if move should be skipped
1068+ logger .info (f"MOVE SHOULD BE SKIPPED? : { move_engine .selected_move .skip_ncmc } " )
1069+
1070+ if move_engine .selected_move .skip_ncmc :
1071+ logger .info ("Skipping NCMC steps because no valid move was proposed.\n " )
1072+ logger .info (f"[DEBUG]: the restraint pose is: restraint_pose_{ move_engine .selected_move .current_pose } " )
1073+ lambda_rest = self ._ncmc_sim .context .getParameter ("lambda_restraints" )
1074+ logger .info (f"[Step { step } ] lambda_restraints = { lambda_rest } " )
1075+ positions = self ._md_sim .context .getState (getPositions = True ).getPositions (asNumpy = True )
1076+ move = self ._move_engine .selected_move
1077+ ligand_coords = positions [move .atom_indices ]
1078+ ligand_com = ligand_coords .mean (axis = 0 )
1079+ logger .info (f"[DEBUG] Ligand COM before MD (skipped NCMC): { ligand_com } " )
1080+ break
10671081 if step == moveStep :
10681082 if hasattr (logger , 'report' ):
10691083 logger .info = logger .report
10701084 logger .info ('Performing %s...' % move_engine .move_name )
1071-
1072- #print("Running move_engine.runEngine() at moveStep")
1085+ restraint_pose = self ._ncmc_sim .context .getParameter (f"restraint_pose_{ move_engine .selected_move .current_pose } " )
1086+ lambda_rest = self ._ncmc_sim .context .getParameter ("lambda_restraints" )
1087+ logger .info (f"[Step { step } ] lambda_restraints = { lambda_rest } " )
1088+ logger .info (f"[Step { step } ] restraint_pose_{ move_engine .selected_move .current_pose } = { restraint_pose } " )
1089+ state = self ._ncmc_sim .context .getState (getEnergy = True , groups = {move_engine .selected_move .restraint_group })
1090+ energy = state .getPotentialEnergy ().value_in_unit (unit .kilojoules_per_mole )
1091+ logger .info (f"[Step { step } ] Restraint Energy (group { move_engine .selected_move .restraint_group } ): { energy :.4f} kJ/mol" )
1092+
1093+ try :
1094+ steric_state = self ._ncmc_sim .context .getState (getEnergy = True , groups = {move_engine .selected_move .steric_group })
1095+ steric_energy = steric_state .getPotentialEnergy ().value_in_unit (unit .kilojoules_per_mole )
1096+ logger .info (f"[Step { step } ]Steric Energy During Move Proposal: { steric_energy :.4f} kJ/mol" )
1097+ except Exception as e :
1098+ logger .warning (f"[Step { step } ] Could not retrieve steric energy: { e } " )
1099+ # Perform the NCMC move (lambda 0 → 0.5 and apply move)
10731100 self ._ncmc_sim .context = move_engine .runEngine (self ._ncmc_sim .context )
1074-
10751101
1076- self ._ncmc_sim .step (1 )
1102+ if move_engine .move .selected_move .skip_ncmc == None :
1103+ logger .info ("No valid dart region found. Skipping reverse NCMC and rejecting move." )
1104+ # Call afterMove to clean up / reset lambda
1105+ self ._ncmc_sim .context = move_engine .selected_move .afterMove (self ._ncmc_sim .context )
1106+ # skip remainder of NCMC (e.g., 0.5 → 1.0)
1107+ break
1108+
10771109
1110+ self ._ncmc_sim .step (1 )
1111+ # if step % 50 == 0:
1112+ # lambda_val = self._ncmc_sim.context._integrator.getGlobalVariableByName("lambda")
1113+ # logger.info(f"NCMC step {step}: lambda = {lambda_val}")
10781114 if step == lastStep :
1115+ logger .info ("AFTER MOVE WILL BE CALLED" )
1116+ lambda_val = self ._ncmc_sim .context ._integrator .getGlobalVariableByName ("lambda" )
1117+ logger .info (f"NCMC step { step } : lambda = { lambda_val } " )
1118+
1119+ # Log sterics after move
1120+ try :
1121+ steric_state = self ._ncmc_sim .context .getState (getEnergy = True , groups = {move_engine .selected_move .steric_group })
1122+ steric_energy = steric_state .getPotentialEnergy ().value_in_unit (unit .kilojoules_per_mole )
1123+ logger .info (f"[Step { step } ] Steric Energy AFTER MOVE (group { move_engine .selected_move .steric_group } ): { steric_energy :.4f} kJ/mol" )
1124+ except Exception as e :
1125+ logger .warning (f"[Step { step } ] Could not retrieve steric energy after move: { e } " )
1126+
10791127 self ._ncmc_sim .context = move_engine .selected_move .afterMove (self ._ncmc_sim .context )
10801128 # Debug: print positions after afterMove
10811129
@@ -1090,10 +1138,7 @@ def _stepNCMC(self, nstepsNC, moveStep, move_engine=None):
10901138 ncmc_state1 = self .getStateFromContext (self ._ncmc_sim .context , self ._state_keys )
10911139 self ._setStateTable ('ncmc' , 'state1' , ncmc_state1 )
10921140
1093- # # Optional: check difference
1094- # import numpy as np
1095- # delta = np.abs(ncmc_state1['positions'] - ncmc_state0['positions'])
1096- # print("Max delta between state0 and state1:", np.max(delta))
1141+
10971142
10981143 def _computeAlchemicalCorrection (self ):
10991144 """Computes the alchemical correction term from switching between the NCMC
@@ -1125,44 +1170,94 @@ def _acceptRejectMove(self, write_move=False):
11251170 write_move : bool, default=False
11261171 If True, writes the proposed NCMC move to a PDB file.
11271172 """
1128- work_ncmc = self ._ncmc_sim . context . _integrator . getLogAcceptanceProbability ( self . _ncmc_sim . context )
1129- randnum = math . log ( np . random . random ())
1173+ move = self ._move_engine . selected_move
1174+ acceptance_ratio = move . acceptance_ratio
11301175
1131- # Compute correction if work_ncmc is not NaN
1132- if not np .isnan (work_ncmc ):
1133- correction_factor = self ._computeAlchemicalCorrection ()
1134- logger .debug (
1135- 'NCMCLogAcceptanceProbability = %.6f + Alchemical Correction = %.6f' % (work_ncmc , correction_factor ))
1136- work_ncmc = work_ncmc + correction_factor
1176+ # Case 1: NCMC move was skipped entirely
1177+ if acceptance_ratio is None :
1178+ self .reject += 1
1179+ logger .info ("NCMC MOVE REJECTED: No valid dart region, skipped move." )
1180+ self ._validate_potential_energy_consistency (
1181+ self .stateTable ['md' ]['state0' ], self ._md_sim .context
1182+ )
1183+ return
11371184
1185+ # Get raw log acceptance probability
1186+ work_ncmc = self ._ncmc_sim .context ._integrator .getLogAcceptanceProbability (self ._ncmc_sim .context )
1187+ randnum = math .log (np .random .random ())
1188+ logger .info (f'Protocol of work: { work_ncmc } ' )
1189+ if np .isnan (work_ncmc ):
1190+ self .reject += 1
1191+ logger .warning ("NCMC MOVE REJECTED: work_ncmc is NaN." )
1192+ self ._validate_potential_energy_consistency (
1193+ self .stateTable ['md' ]['state0' ], self ._md_sim .context
1194+ )
1195+ return
1196+
1197+ # Apply alchemical correction
1198+ correction_factor = self ._computeAlchemicalCorrection ()
1199+ logger .info (f"Correcion Factor: { correction_factor } " )
1200+ work_ncmc += correction_factor
1201+
1202+ # Optional: apply statistical restraint correction
1203+ # if acceptance_ratio != 1.0:
1204+ # work_ncmc += math.log(acceptance_ratio)
1205+
1206+ # Metropolis acceptance criterion
11381207 if work_ncmc > randnum :
11391208 self .accept += 1
1140- logger .info ('NCMC MOVE ACCEPTED: work_ncmc {} > randnum {}' .format (work_ncmc , randnum ))
1209+ logger .info (f"NCMC MOVE ACCEPTED: work_ncmc { work_ncmc :.4f} > randnum { randnum :.4f} " )
1210+ for move in self ._move_engine .moves :
1211+ if isinstance (move , SmartDartMove ):
1212+ move .num_accepted_darts += 1
11411213
1142- # If accept move, sync NCMC state to MD context
1214+ # Sync NCMC state to MD context
11431215 ncmc_state1 = self .stateTable ['ncmc' ]['state1' ]
1144- self ._md_sim .context = self .setContextFromState (self ._md_sim .context , ncmc_state1 , velocities = False )
1216+ self ._md_sim .context = self .setContextFromState (
1217+ self ._md_sim .context , ncmc_state1 , velocities = False
1218+ )
11451219
11461220 if write_move :
1147- utils .saveSimulationFrame (self . _md_sim , '{}acc-it{}.pdb' . format ( self . _config [ 'outfname' ],
1148- self .currentIter ))
1149-
1221+ utils .saveSimulationFrame (
1222+ self . _md_sim , f" { self ._config [ 'outfname' ] } acc-it { self . currentIter } .pdb"
1223+ )
11501224 else :
11511225 self .reject += 1
1152- logger .info ('NCMC MOVE REJECTED: work_ncmc {} < {}' .format (work_ncmc , randnum ))
1153-
1154- # If reject move, do nothing,
1155- # NCMC simulation be updated from MD Simulation next iteration.
1156-
1157- # Potential energy should be from last MD step in the previous iteration
1158- md_state0 = self .stateTable ['md' ]['state0' ]
1159- md_PE = self ._md_sim .context .getState (getEnergy = True ).getPotentialEnergy ()
1160- if not math .isclose (md_state0 ['potential_energy' ]._value , md_PE ._value , rel_tol = float ('1e-%s' % rtol )):
1161- logger .error (
1162- 'Last MD potential energy %s != Current MD potential energy %s. Potential energy should match the prior state.'
1163- % (md_state0 ['potential_energy' ], md_PE ))
1164- sys .exit (1 )
1226+ logger .info (f"NCMC MOVE REJECTED: work_ncmc { work_ncmc :.4f} < randnum { randnum :.4f} " )
1227+ self ._validate_potential_energy_consistency (
1228+ self .stateTable ['md' ]['state0' ], self ._md_sim .context
1229+ )
1230+
1231+ def _validate_potential_energy_consistency (self , md_state0 , md_context , rtol = 5 ):
1232+ """
1233+ Validates that the MD potential energy has not drifted after a rejected NCMC move.
11651234
1235+ Parameters
1236+ ----------
1237+ md_state0 : dict
1238+ Stored MD state (before the NCMC move) from self.stateTable['md']['state0'].
1239+ md_context : openmm.Context
1240+ The current OpenMM context for the MD simulation.
1241+ rtol : int, optional
1242+ Relative tolerance for comparing potential energies (default is 1e-8).
1243+
1244+ Raises
1245+ ------
1246+ RuntimeError
1247+ If the current MD potential energy does not match the stored value within tolerance.
1248+ """
1249+ md_PE_current = md_context .getState (getEnergy = True ).getPotentialEnergy ()
1250+ md_PE_stored = md_state0 ['potential_energy' ]
1251+
1252+ if not math .isclose (md_PE_stored ._value , md_PE_current ._value , rel_tol = 10 ** (- rtol )):
1253+ logger .error (
1254+ f"Potential energy mismatch after rejected move:\n "
1255+ f"Stored MD PE: { md_PE_stored } \n "
1256+ f"Current MD PE: { md_PE_current } \n "
1257+ f"Relative diff: { abs (md_PE_stored ._value - md_PE_current ._value )} "
1258+ )
1259+ raise RuntimeError ("Potential energy mismatch: MD state was not preserved correctly after rejection." )
1260+
11661261 def _resetSimulations (self , temperature = None ):
11671262 """At the end of each iteration:
11681263
@@ -1194,6 +1289,16 @@ def _stepMD(self, nstepsMD):
11941289 """
11951290 logger .info ('Advancing %i MD steps...' % (nstepsMD ))
11961291 self ._md_sim .currentIter = self .currentIter
1292+
1293+ move = self ._move_engine .selected_move
1294+ if move .skip_ncmc :
1295+ positions = self ._md_sim .context .getState (getPositions = True ).getPositions (asNumpy = True )
1296+ ligand_com_pre = positions [move .atom_indices ].mean (axis = 0 )
1297+ logger .info (f"[DEBUG] Ligand COM BEFORE MD (iter { self .currentIter } ): { ligand_com_pre } " )
1298+ lambda_rest = self ._ncmc_sim .context .getParameter ("lambda_restraints" )
1299+ logger .info (f"[SKIP-NCMC]: restraints should be off" )
1300+ logger .info (f"[DEBUG] lambda_restraints before MD (iter { self .currentIter } ): { lambda_rest } " )
1301+
11971302 # Retrieve MD state before proposed move
11981303 # Helps determine if previous iteration placed ligand poorly
11991304 md_state0 = self .stateTable ['md' ]['state0' ]
@@ -1210,6 +1315,14 @@ def _stepMD(self, nstepsMD):
12101315 'MD-fail-it%s-md%i.pdb' % (self .currentIter , self ._md_sim .currentStep ))
12111316 sys .exit (1 )
12121317
1318+ # Log ligand COM after MD
1319+ if move .skip_ncmc :
1320+ positions = self ._md_sim .context .getState (getPositions = True ).getPositions (asNumpy = True )
1321+ ligand_com_post = positions [move .atom_indices ].mean (axis = 0 )
1322+ logger .info (f"[DEBUG] Ligand COM AFTER MD (iter { self .currentIter } ): { ligand_com_post } " )
1323+ lambda_rest = self ._ncmc_sim .context .getParameter ("lambda_restraints" )
1324+ logger .info (f"[DEBUG] lambda_restraints AFTER MD (iter { self .currentIter } ): { lambda_rest } " )
1325+
12131326 def run (self , nIter = 0 , nstepsNC = 0 , moveStep = 0 , nstepsMD = 0 , temperature = 300 , write_move = False , ** config ):
12141327 """Executes the BLUES engine to iterate over the actions:
12151328 Perform NCMC simulation, perform proposed move, accepts/rejects move,
@@ -1244,21 +1357,19 @@ def run(self, nIter=0, nstepsNC=0, moveStep=0, nstepsMD=0, temperature=300, writ
12441357 self .currentIter = N
12451358 logger .info ('BLUES Iteration: %s' % N )
12461359 self ._syncStatesMDtoNCMC ()
1247- #print("✅ _syncStatesMDtoNCMC")
12481360 self ._stepNCMC (nstepsNC , moveStep )
1249- #print("✅ _stepNCMC")
12501361 self ._acceptRejectMove (write_move )
1251- #print("✅ _acceptRejectMove")
1252- #print(f'what is temperature: {temperature}')
12531362 self ._resetSimulations (temperature )
1254- #print("✅ _resetSimulations")
12551363 self ._stepMD (nstepsMD )
1256- #print("✅ _stepMD")
1257- #print(f'NITER: {N}/{nIter}')
12581364 # END OF NITER
12591365 self .acceptRatio = self .accept / float (nIter )
12601366 logger .info ('Acceptance Ratio: %s' % self .acceptRatio )
12611367 logger .info ('nIter: %s ' % nIter )
1368+ for move in self ._move_engine .moves :
1369+ if isinstance (move , SmartDartMove ):
1370+ acceptance_ratio = move .num_accepted_darts / float (move .num_proposed_darts )
1371+ logger .info (f"[SmartDartMove] Accepted { move .num_accepted_darts } / { move .num_proposed_darts } darts" )
1372+ logger .info (f"[SmartDartMove] Dart acceptance ratio: { acceptance_ratio :.3f} " )
12621373
12631374class MonteCarloSimulation (BLUESSimulation ):
12641375 """Simulation class provides the functions that perform the MonteCarlo run.
@@ -1298,6 +1409,9 @@ def _acceptRejectMove(self, temperature=None):
12981409 self .accept += 1
12991410 logger .info ('MC MOVE ACCEPTED: work_mc {} > randnum {}' .format (work_mc , randnum ))
13001411 self ._md_sim .context .setPositions (md_state1 ['positions' ])
1412+ for move in self ._move_engine .moves :
1413+ if isinstance (move , SmartDartMove ):
1414+ move .num_accepted_darts += 1
13011415 else :
13021416 self .reject += 1
13031417 logger .info ('MC MOVE REJECTED: work_mc {} < {}' .format (work_mc , randnum ))
@@ -1329,10 +1443,23 @@ def run(self, nIter=0, mc_per_iter=0, nstepsMD=0, temperature=300, write_move=Fa
13291443
13301444 self ._syncStatesMDtoNCMC ()
13311445 for N in range (nIter ):
1446+ iter_start = time .time ()
13321447 self .currentIter = N
13331448 logger .info ('MonteCarlo Iteration: %s' % N )
13341449 for i in range (mc_per_iter ):
13351450 self ._syncStatesMDtoNCMC ()
13361451 self ._stepMC_ ()
13371452 self ._acceptRejectMove (temperature )
13381453 self ._stepMD (nstepsMD )
1454+ iter_end = time .time ()
1455+ logger .info (f"Iteration { N } took { iter_end - iter_start :.2f} sec" )
1456+
1457+ self .acceptRatio = self .accept / float (nIter )
1458+ logger .info ('Acceptance Ratio: %s' % self .acceptRatio )
1459+ logger .info ('nIter: %s ' % nIter )
1460+ # After move, access found_dart
1461+ for move in self ._move_engine .moves :
1462+ if isinstance (move , SmartDartMove ):
1463+ acceptance_ratio = move .num_accepted_darts / float (move .num_proposed_darts )
1464+ logger .info (f"[SmartDartMove] Accepted { move .num_accepted_darts } / { move .num_proposed_darts } darts" )
1465+ logger .info (f"[SmartDartMove] Dart acceptance ratio: { acceptance_ratio :.3f} " )
0 commit comments