Skip to content

Commit 0e3fab1

Browse files
committed
BeforeMove and AfterMove() turn off boresch restraints
1 parent 5022a1e commit 0e3fab1

File tree

2 files changed

+163
-20
lines changed

2 files changed

+163
-20
lines changed

blues/moves.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def __init__(self,
203203
restraints=None, K_r=None, K_angle=None,
204204
restrained_receptor_atoms=None, restrained_ligand_atoms=None,
205205
lambda_restraints = "max(0, 1-(1/0.10)*abs(lambda-0.5))",
206-
no_move=False,
206+
skip_move=False,
207207
):
208208
self.structure = structure
209209
self.resname = resname
@@ -221,7 +221,7 @@ def __init__(self,
221221
self.restrained_receptor_atoms = restrained_receptor_atoms
222222
self.restrained_ligand_atoms = restrained_ligand_atoms
223223
self.positions = structure[self.atom_indices].positions
224-
self.no_move = no_move
224+
self.skip_move = skip_move
225225
if self.ligand_indices:
226226
self.topology = structure[self.ligand_indices].topology
227227
self.positions = structure[self.ligand_indices].positions
@@ -468,6 +468,52 @@ def _calculateProperties(self):
468468
print(f"🚨 WARNING: Mass count ({len(self.masses)}) does not match atom count ({len(self.atom_indices)})!")
469469

470470
self.center_of_mass = self.getCenterOfMass(self.positions, self.masses)
471+
472+
def beforeMove(self, context):
473+
474+
"""
475+
Called before NCMC begins. Checks if ligand is in a predefined dart region,
476+
and sets context parameters accordingly. If not, NCMC will be skipped.
477+
"""
478+
if not self.restraints:
479+
return context
480+
for i in range(len(self.binding_mode_traj)):
481+
context.setParameter(f'restraint_pose_{i}', 1.0)
482+
483+
return context
484+
485+
def afterMove(self, context):
486+
"""
487+
If restraints were specified,Check if current positions are in
488+
the same pose as the specified restraint.
489+
If not, reject the move (to maintain detailed balance).
490+
491+
This method is called at the end of the NCMC portion if the
492+
context needs to be checked or modified before performing the move
493+
at the halfway point.
494+
495+
Parameters
496+
----------
497+
context: simtk.openmm.Context object
498+
Context containing the positions to be moved.
499+
Returns
500+
-------
501+
context: simtk.openmm.Context object
502+
The same input context, but whose context were changed by this function.
503+
504+
"""
505+
if not self.restraints:
506+
return context
507+
508+
# Turn off all restraints after move
509+
for i in range(len(self.binding_mode_traj)):
510+
context.setParameter(f'restraint_pose_{i}', 0.0)
511+
context.setParameter('lambda_restraints', 0.0)
512+
context.setParameter("lambda_sterics", 1.0)
513+
context.setParameter('lambda_electrostatics', 1.0)
514+
515+
return context
516+
471517
def move(self, context):
472518
"""Function that performs a random rotation about the
473519
center of mass of the ligand.
@@ -482,7 +528,7 @@ def move(self, context):
482528
context: openmm.openmm.Context object
483529
The same input context, but whose positions were changed by this function.
484530
"""
485-
if self.no_move:
531+
if not self.skip_move:
486532
return context
487533
positions = context.getState(getPositions=True).getPositions(asNumpy=True)
488534

blues/simulation.py

Lines changed: 114 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1040,52 +1040,89 @@ def _syncStatesMDtoNCMC(self):
10401040
# Sync MD state to the NCMC context
10411041
self._ncmc_sim.context = self.setContextFromState(self._ncmc_sim.context, md_state0)
10421042

1043+
10431044
def _stepNCMC(self, nstepsNC, moveStep, move_engine=None):
10441045
"""Advance the NCMC simulation."""
1045-
#print("Running _stepNCMC...")
1046-
#print(f"nsteps: {nstepsNC}, moveStep: {moveStep}")
1047-
logger.info('Advancing %i NCMC switching steps...' % (nstepsNC))
10481046

1049-
# Retrieve NCMC state before proposed move
10501047
ncmc_state0 = self.getStateFromContext(self._ncmc_sim.context, self._state_keys)
1051-
#print("Captured ncmc_state0")
10521048
#print(ncmc_state0['positions'])
10531049
self._setStateTable('ncmc', 'state0', ncmc_state0)
1054-
1050+
logger.info(f"SetTable ncmc_state0")
10551051
# Select the move to perform
10561052
if not move_engine:
10571053
move_engine = self._move_engine
10581054
self._ncmc_sim.currentIter = self.currentIter
1059-
1055+
logger.info(f'move engine selected: {move_engine}')
10601056
move_engine.selectMove()
10611057
#print(f"Selected move: {move_engine.move_name}")
10621058

10631059
lastStep = nstepsNC - 1
1060+
logger.info(f"LOOPING: nstepsNC: {nstepsNC}")
10641061
for step in range(int(nstepsNC)):
10651062
try:
10661063
if not step:
1067-
#print("Calling beforeMove()")
1064+
logger.info("Calling beforeMove()")
10681065
self._ncmc_sim.context = move_engine.selected_move.beforeMove(self._ncmc_sim.context)
1069-
10701066
if step == moveStep:
10711067
if hasattr(logger, 'report'):
10721068
logger.info = logger.report
10731069
logger.info('Performing %s...' % move_engine.move_name)
10741070

1075-
#print("Running move_engine.runEngine() at moveStep")
1071+
try:
1072+
steric_state = self._ncmc_sim.context.getState(getEnergy=True, groups={move_engine.selected_move.steric_group})
1073+
steric_energy = steric_state.getPotentialEnergy().value_in_unit(unit.kilojoules_per_mole)
1074+
logger.info(f"[Step {step}]Steric Energy During Move Proposal: {steric_energy:.4f} kJ/mol")
1075+
except Exception as e:
1076+
logger.warning(f"[Step {step}] Could not retrieve steric energy: {e}")
1077+
# Perform the NCMC move (lambda 0 → 0.5 and apply move)
10761078
self._ncmc_sim.context = move_engine.runEngine(self._ncmc_sim.context)
1077-
1079+
if move_engine.selected_move.acceptance_ratio == None:
1080+
logger.info("No valid dart region found. Skipping reverse NCMC and rejecting move.")
1081+
# Call afterMove to clean up / reset lambda
1082+
self._ncmc_sim.context = move_engine.selected_move.afterMove(self._ncmc_sim.context)
1083+
# skip remainder of NCMC (e.g., 0.5 → 1.0)
1084+
break
1085+
#
1086+
# state = self._ncmc_sim.context.getState(getPositions=True, getEnergy=True)
1087+
# steric_state = self._ncmc_sim.context.getState(getEnergy=True, groups={move_engine.selected_move.steric_group})
1088+
# steric_energy = steric_state.getPotentialEnergy().value_in_unit(unit.kilojoules_per_mole)
10781089
lambda_val = self._ncmc_sim.context._integrator.getGlobalVariableByName("lambda")
1079-
# if lambda val is near 0.0
1080-
if lambda_val < 0.000100 or step == 0 or abs(lambda_val - 0.5) < 1e-4 or abs(lambda_val - 1.0) < 1e-4:
1090+
if abs(lambda_val - 0.0) < 1e-4 or abs(lambda_val - 0.5) < 1e-4 or abs(lambda_val - 0.9) < 1e-4:
1091+
steric_state = self._ncmc_sim.context.getState(getEnergy=True, groups={move_engine.selected_move.steric_group})
1092+
steric_energy = steric_state.getPotentialEnergy().value_in_unit(unit.kilojoules_per_mole)
10811093
state = self._ncmc_sim.context.getState(getPositions=True, getEnergy=True)
10821094
logger.info(f"[λ={lambda_val:.6f} and Step={step}] Total potential energy: {state.getPotentialEnergy()}")
1095+
logger.info(f"[λ={lambda_val:.6f} and Step={step}] Steric Energy: {steric_energy:.4f} kJ/mol and Total potential energy: {state.getPotentialEnergy()} ")
10831096
lambda_s = self._ncmc_sim.context.getParameter("lambda_sterics")
10841097
lambda_e = self._ncmc_sim.context.getParameter("lambda_electrostatics")
10851098
logger.info(f"[Step {step}] λ_sterics = {lambda_s}, λ_electrostatics = {lambda_e}")
1099+
#self._log_restraint_energies(step=step, move_engine = move_engine)
1100+
10861101
self._ncmc_sim.step(1)
10871102

1103+
integrator = self._ncmc_sim.context._integrator
1104+
lambda_val = integrator.getGlobalVariableByName("lambda")
1105+
#logger.info(f"[NCMC step {step}] lambda = {lambda_val:.6f} | protocol_work = {protocol_work:.4f}")
1106+
#self._log_restraint_energies(step=step, move_engine = move_engine, print_output=False)
1107+
10881108
if step == lastStep:
1109+
logger.info("AFTER MOVE WILL BE CALLED")
1110+
lambda_val = self._ncmc_sim.context._integrator.getGlobalVariableByName("lambda")
1111+
logger.info(f"NCMC step {step}: lambda = {lambda_val}")
1112+
state = self._ncmc_sim.context.getState(getEnergy=True)
1113+
logger.info(f"[λ={lambda_val:.6f} and Step={step}] Total potential energy: {state.getPotentialEnergy()}")
1114+
lambda_s = self._ncmc_sim.context.getParameter("lambda_sterics")
1115+
lambda_e = self._ncmc_sim.context.getParameter("lambda_electrostatics")
1116+
logger.info(f"[Step {step}] λ_sterics = {lambda_s}, λ_electrostatics = {lambda_e}")
1117+
#self._log_restraint_energies(step=step, move_engine = move_engine)
1118+
# Log sterics after move
1119+
try:
1120+
steric_state = self._ncmc_sim.context.getState(getEnergy=True, groups={move_engine.selected_move.steric_group})
1121+
steric_energy = steric_state.getPotentialEnergy().value_in_unit(unit.kilojoules_per_mole)
1122+
logger.info(f"[Step {step}] Steric Energy AFTER MOVE (group {move_engine.selected_move.steric_group}): {steric_energy:.4f} kJ/mol")
1123+
except Exception as e:
1124+
logger.warning(f"[Step {step}] Could not retrieve steric energy after move: {e}")
1125+
10891126
self._ncmc_sim.context = move_engine.selected_move.afterMove(self._ncmc_sim.context)
10901127
# Debug: print positions after afterMove
10911128

@@ -1100,11 +1137,71 @@ def _stepNCMC(self, nstepsNC, moveStep, move_engine=None):
11001137
ncmc_state1 = self.getStateFromContext(self._ncmc_sim.context, self._state_keys)
11011138
self._setStateTable('ncmc', 'state1', ncmc_state1)
11021139

1140+
# def _stepNCMC(self, nstepsNC, moveStep, move_engine=None):
1141+
# """Advance the NCMC simulation."""
1142+
# #print("Running _stepNCMC...")
1143+
# #print(f"nsteps: {nstepsNC}, moveStep: {moveStep}")
1144+
# logger.info('Advancing %i NCMC switching steps...' % (nstepsNC))
1145+
1146+
# # Retrieve NCMC state before proposed move
1147+
# ncmc_state0 = self.getStateFromContext(self._ncmc_sim.context, self._state_keys)
1148+
# #print("Captured ncmc_state0")
1149+
# #print(ncmc_state0['positions'])
1150+
# self._setStateTable('ncmc', 'state0', ncmc_state0)
1151+
1152+
# # Select the move to perform
1153+
# if not move_engine:
1154+
# move_engine = self._move_engine
1155+
# self._ncmc_sim.currentIter = self.currentIter
1156+
1157+
# move_engine.selectMove()
1158+
# #print(f"Selected move: {move_engine.move_name}")
1159+
1160+
# lastStep = nstepsNC - 1
1161+
# for step in range(int(nstepsNC)):
1162+
# try:
1163+
# if not step:
1164+
# #print("Calling beforeMove()")
1165+
# self._ncmc_sim.context = move_engine.selected_move.beforeMove(self._ncmc_sim.context)
1166+
1167+
# if step == moveStep:
1168+
# if hasattr(logger, 'report'):
1169+
# logger.info = logger.report
1170+
# logger.info('Performing %s...' % move_engine.move_name)
1171+
1172+
# #print("Running move_engine.runEngine() at moveStep")
1173+
# self._ncmc_sim.context = move_engine.runEngine(self._ncmc_sim.context)
1174+
1175+
# lambda_val = self._ncmc_sim.context._integrator.getGlobalVariableByName("lambda")
1176+
# # if lambda val is near 0.0
1177+
# if lambda_val < 0.000100 or step == 0 or abs(lambda_val - 0.5) < 1e-4 or abs(lambda_val - 1.0) < 1e-4:
1178+
# state = self._ncmc_sim.context.getState(getPositions=True, getEnergy=True)
1179+
# logger.info(f"[λ={lambda_val:.6f} and Step={step}] Total potential energy: {state.getPotentialEnergy()}")
1180+
# lambda_s = self._ncmc_sim.context.getParameter("lambda_sterics")
1181+
# lambda_e = self._ncmc_sim.context.getParameter("lambda_electrostatics")
1182+
# logger.info(f"[Step {step}] λ_sterics = {lambda_s}, λ_electrostatics = {lambda_e}")
1183+
# self._ncmc_sim.step(1)
1184+
1185+
# if step == lastStep:
1186+
# self._ncmc_sim.context = move_engine.selected_move.afterMove(self._ncmc_sim.context)
1187+
# # Debug: print positions after afterMove
1188+
1189+
# except Exception as e:
1190+
# import traceback
1191+
# traceback.print_tb(e.__traceback__)
1192+
# logger.error(e)
1193+
# move_engine.selected_move._error(self._ncmc_sim.context)
1194+
# break
1195+
1196+
# # ncmc_state1 stores the state AFTER a proposed move
1197+
# ncmc_state1 = self.getStateFromContext(self._ncmc_sim.context, self._state_keys)
1198+
# self._setStateTable('ncmc', 'state1', ncmc_state1)
1199+
11031200

1104-
# # Optional: check difference
1105-
# import numpy as np
1106-
# delta = np.abs(ncmc_state1['positions'] - ncmc_state0['positions'])
1107-
# print("Max delta between state0 and state1:", np.max(delta))
1201+
# # # Optional: check difference
1202+
# # import numpy as np
1203+
# # delta = np.abs(ncmc_state1['positions'] - ncmc_state0['positions'])
1204+
# # print("Max delta between state0 and state1:", np.max(delta))
11081205

11091206
def _computeAlchemicalCorrection(self):
11101207
"""Computes the alchemical correction term from switching between the NCMC

0 commit comments

Comments
 (0)