Skip to content

Commit 5850f07

Browse files
committed
Add pose-aware Boresch and RMSD restraints with optional COM term
1 parent 32af689 commit 5850f07

File tree

2 files changed

+454
-0
lines changed

2 files changed

+454
-0
lines changed

blues/moves.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@
2222
from openmm import unit
2323
import tempfile
2424
import numpy as np
25+
import openmm
26+
27+
from blues.integrators import AlchemicalExternalLangevinIntegrator, AlchemicalExternalRestrainedLangevinIntegrator
28+
from blues.restraints import add_boresch_restraints
2529

2630

2731
try:
@@ -208,6 +212,152 @@ def __init__(self, structure, resname='LIG', ligand_indices=None, random_state=N
208212

209213
self._calculateProperties()
210214

215+
216+
def initializeSystem(self, system, integrator, config):
217+
"""
218+
Changes the system by adding forces corresponding to restraints (if specified)
219+
and freeze protein and/or waters, if specified in __init__()
220+
221+
222+
Parameters
223+
----------
224+
system : simtk.openmm.System object
225+
System to be modified.
226+
integrator : simtk.openmm.Integrator object
227+
Integrator to be modified.
228+
Returns
229+
-------
230+
system : simtk.openmm.System object
231+
The modified System object.
232+
integrator : simtk.openmm.Integrator object
233+
The modified Integrator object.
234+
235+
"""
236+
new_sys = system
237+
238+
def find_force_group(system, force_type):
239+
"""Returns the force group number for the first force of the given type."""
240+
for force in system.getForces():
241+
if isinstance(force, force_type):
242+
return force.getForceGroup()
243+
raise ValueError(f"No force of type {force_type.__name__} found in system.")
244+
245+
steric_group = find_force_group(new_sys, openmm.NonbondedForce)
246+
self.steric_group = steric_group
247+
248+
if self.restraints:
249+
return self.initializeRestraints(new_sys, integrator, config)
250+
251+
return new_sys, integrator
252+
253+
254+
def initializeRestraints(self, system: openmm.System, integrator: openmm.Integrator, config:dict):
255+
"""
256+
Initialize the restraint forces for the system.
257+
258+
Parameters
259+
----------
260+
system : openmm.System
261+
The OpenMM system to be modified
262+
integrator : openmm.Integrator
263+
The current integrator to be replaced with a restrained version
264+
265+
Returns
266+
-------
267+
system : openmm.System
268+
The modified System object.
269+
integrator : openmm.Integrator
270+
The modified Integrator object.
271+
"""
272+
# if self.restrained_receptor_atoms is None:
273+
# self.restrained_receptor_atoms = self.basis_particles
274+
275+
new_sys = system
276+
old_int = integrator
277+
278+
# Get available force group
279+
force_list = new_sys.getForces()
280+
group_list = list(set([force.getForceGroup() for force in force_list]))
281+
group_avail = [j for j in list(range(32)) if j not in group_list]
282+
283+
if len(group_avail) < len(self.restraints):
284+
raise ValueError("Not enough available force groups for all requested restraints.")
285+
286+
self.restraint_groups = {} # Dict to store group for each restraint type
287+
288+
for i, restraint_type in enumerate(sorted(self.restraints)):
289+
self.restraint_groups[restraint_type] = group_avail[i]
290+
291+
292+
# Verify the old integrator has the required attributes
293+
if not hasattr(old_int, '_alchemical_functions'):
294+
raise AttributeError("Old integrator missing _alchemical_functions")
295+
296+
# Get system parameters from old integrator
297+
old_int._system_parameters = {system_parameter for system_parameter in old_int._alchemical_functions.keys()}
298+
# Extract kwargs for the new integrator
299+
integrator_kwargs = config or {}
300+
301+
# Get integrator kwargs if available, otherwise use defaults
302+
# Create new integrator with restraints
303+
304+
if self.old_restraint:
305+
new_int = AlchemicalExternalRestrainedLangevinIntegrator(
306+
restraint_group=set(self.restraint_groups.values()),
307+
lambda_restraints=self.lambda_restraints,
308+
alchemical_functions = old_int._alchemical_functions,
309+
nsteps_neq=integrator_kwargs['nstepsNC'],
310+
nprop=integrator_kwargs['nprop'],
311+
prop_lambda=integrator_kwargs['propLambda'],
312+
splitting=integrator_kwargs['splitting'],)
313+
314+
else:
315+
new_int = AlchemicalExternalLangevinIntegrator(
316+
restraint_group=set(self.restraint_groups.values()),
317+
lambda_restraints=self.lambda_restraints,
318+
alchemical_functions = old_int._alchemical_functions,
319+
nsteps_neq=integrator_kwargs['nstepsNC'],
320+
nprop=integrator_kwargs['nprop'],
321+
prop_lambda=integrator_kwargs['propLambda'],
322+
splitting=integrator_kwargs['splitting'])
323+
#**old_int.int_kwargs)
324+
325+
new_int.reset()
326+
327+
# Verify we have the required trajectory data
328+
if not hasattr(self, 'binding_mode_traj') or len(self.binding_mode_traj) == 0:
329+
raise ValueError("No binding mode trajectory available for restraints")
330+
331+
for index, pose in enumerate(self.binding_mode_traj):
332+
333+
# Cache the positions once
334+
pose_nm = numpy.array(pose.openmm_positions(0).value_in_unit(unit.nanometers))
335+
336+
pose_pos = pose_nm[self.atom_indices]
337+
pose_allpos = pose_nm * unit.nanometers
338+
339+
# Optional: update just the ligand portion of a shared template
340+
new_pos = pose_nm.copy()
341+
new_pos[self.atom_indices] = pose_pos
342+
new_pos = new_pos * unit.nanometers
343+
344+
if 'boresch' in self.restraints:
345+
# Only validate force constants, allow atoms to be None
346+
if not all(x is not None for x in [self.K_r, self.K_angle]):
347+
raise ValueError("Missing required force constants for Boresch restraints (K_r, K_angle)")
348+
349+
350+
new_sys = add_boresch_restraints(sys=new_sys, struct=self.structure, pos=pose_allpos, ligand_atoms=self.atom_indices,
351+
pose_num=index, force_group=self.restraint_groups['boresch'],
352+
restrained_receptor_atoms=self.restrained_receptor_atoms, restrained_ligand_atoms=self.restrained_ligand_atoms,
353+
K_r=self.K_r, K_angle=self.K_angle, K_RMSD=self.K_RMSD, RMSD0=self.RMSD0,
354+
K_com=self.K_com)
355+
356+
if 'boresch' not in self.restraints and 'rmsd' not in self.restraints:
357+
raise ValueError(f'Invalid restraint type: {self.restraints}')
358+
359+
return new_sys, new_int
360+
211361
def getAtomIndices(self, structure, resname):
212362
"""
213363
Get atom indices of a ligand from ParmEd Structure.

0 commit comments

Comments
 (0)