|
22 | 22 | from openmm import unit |
23 | 23 | import tempfile |
24 | 24 | import numpy as np |
| 25 | +import openmm |
| 26 | + |
| 27 | +from blues.integrators import AlchemicalExternalLangevinIntegrator, AlchemicalExternalRestrainedLangevinIntegrator |
| 28 | +from blues.restraints import add_boresch_restraints |
25 | 29 |
|
26 | 30 |
|
27 | 31 | try: |
@@ -208,6 +212,152 @@ def __init__(self, structure, resname='LIG', ligand_indices=None, random_state=N |
208 | 212 |
|
209 | 213 | self._calculateProperties() |
210 | 214 |
|
| 215 | + |
| 216 | + def initializeSystem(self, system, integrator, config): |
| 217 | + """ |
| 218 | + Changes the system by adding forces corresponding to restraints (if specified) |
| 219 | + and freeze protein and/or waters, if specified in __init__() |
| 220 | +
|
| 221 | +
|
| 222 | + Parameters |
| 223 | + ---------- |
| 224 | + system : simtk.openmm.System object |
| 225 | + System to be modified. |
| 226 | + integrator : simtk.openmm.Integrator object |
| 227 | + Integrator to be modified. |
| 228 | + Returns |
| 229 | + ------- |
| 230 | + system : simtk.openmm.System object |
| 231 | + The modified System object. |
| 232 | + integrator : simtk.openmm.Integrator object |
| 233 | + The modified Integrator object. |
| 234 | +
|
| 235 | + """ |
| 236 | + new_sys = system |
| 237 | + |
| 238 | + def find_force_group(system, force_type): |
| 239 | + """Returns the force group number for the first force of the given type.""" |
| 240 | + for force in system.getForces(): |
| 241 | + if isinstance(force, force_type): |
| 242 | + return force.getForceGroup() |
| 243 | + raise ValueError(f"No force of type {force_type.__name__} found in system.") |
| 244 | + |
| 245 | + steric_group = find_force_group(new_sys, openmm.NonbondedForce) |
| 246 | + self.steric_group = steric_group |
| 247 | + |
| 248 | + if self.restraints: |
| 249 | + return self.initializeRestraints(new_sys, integrator, config) |
| 250 | + |
| 251 | + return new_sys, integrator |
| 252 | + |
| 253 | + |
| 254 | + def initializeRestraints(self, system: openmm.System, integrator: openmm.Integrator, config:dict): |
| 255 | + """ |
| 256 | + Initialize the restraint forces for the system. |
| 257 | +
|
| 258 | + Parameters |
| 259 | + ---------- |
| 260 | + system : openmm.System |
| 261 | + The OpenMM system to be modified |
| 262 | + integrator : openmm.Integrator |
| 263 | + The current integrator to be replaced with a restrained version |
| 264 | +
|
| 265 | + Returns |
| 266 | + ------- |
| 267 | + system : openmm.System |
| 268 | + The modified System object. |
| 269 | + integrator : openmm.Integrator |
| 270 | + The modified Integrator object. |
| 271 | + """ |
| 272 | + # if self.restrained_receptor_atoms is None: |
| 273 | + # self.restrained_receptor_atoms = self.basis_particles |
| 274 | + |
| 275 | + new_sys = system |
| 276 | + old_int = integrator |
| 277 | + |
| 278 | + # Get available force group |
| 279 | + force_list = new_sys.getForces() |
| 280 | + group_list = list(set([force.getForceGroup() for force in force_list])) |
| 281 | + group_avail = [j for j in list(range(32)) if j not in group_list] |
| 282 | + |
| 283 | + if len(group_avail) < len(self.restraints): |
| 284 | + raise ValueError("Not enough available force groups for all requested restraints.") |
| 285 | + |
| 286 | + self.restraint_groups = {} # Dict to store group for each restraint type |
| 287 | + |
| 288 | + for i, restraint_type in enumerate(sorted(self.restraints)): |
| 289 | + self.restraint_groups[restraint_type] = group_avail[i] |
| 290 | + |
| 291 | + |
| 292 | + # Verify the old integrator has the required attributes |
| 293 | + if not hasattr(old_int, '_alchemical_functions'): |
| 294 | + raise AttributeError("Old integrator missing _alchemical_functions") |
| 295 | + |
| 296 | + # Get system parameters from old integrator |
| 297 | + old_int._system_parameters = {system_parameter for system_parameter in old_int._alchemical_functions.keys()} |
| 298 | + # Extract kwargs for the new integrator |
| 299 | + integrator_kwargs = config or {} |
| 300 | + |
| 301 | + # Get integrator kwargs if available, otherwise use defaults |
| 302 | + # Create new integrator with restraints |
| 303 | + |
| 304 | + if self.old_restraint: |
| 305 | + new_int = AlchemicalExternalRestrainedLangevinIntegrator( |
| 306 | + restraint_group=set(self.restraint_groups.values()), |
| 307 | + lambda_restraints=self.lambda_restraints, |
| 308 | + alchemical_functions = old_int._alchemical_functions, |
| 309 | + nsteps_neq=integrator_kwargs['nstepsNC'], |
| 310 | + nprop=integrator_kwargs['nprop'], |
| 311 | + prop_lambda=integrator_kwargs['propLambda'], |
| 312 | + splitting=integrator_kwargs['splitting'],) |
| 313 | + |
| 314 | + else: |
| 315 | + new_int = AlchemicalExternalLangevinIntegrator( |
| 316 | + restraint_group=set(self.restraint_groups.values()), |
| 317 | + lambda_restraints=self.lambda_restraints, |
| 318 | + alchemical_functions = old_int._alchemical_functions, |
| 319 | + nsteps_neq=integrator_kwargs['nstepsNC'], |
| 320 | + nprop=integrator_kwargs['nprop'], |
| 321 | + prop_lambda=integrator_kwargs['propLambda'], |
| 322 | + splitting=integrator_kwargs['splitting']) |
| 323 | + #**old_int.int_kwargs) |
| 324 | + |
| 325 | + new_int.reset() |
| 326 | + |
| 327 | + # Verify we have the required trajectory data |
| 328 | + if not hasattr(self, 'binding_mode_traj') or len(self.binding_mode_traj) == 0: |
| 329 | + raise ValueError("No binding mode trajectory available for restraints") |
| 330 | + |
| 331 | + for index, pose in enumerate(self.binding_mode_traj): |
| 332 | + |
| 333 | + # Cache the positions once |
| 334 | + pose_nm = numpy.array(pose.openmm_positions(0).value_in_unit(unit.nanometers)) |
| 335 | + |
| 336 | + pose_pos = pose_nm[self.atom_indices] |
| 337 | + pose_allpos = pose_nm * unit.nanometers |
| 338 | + |
| 339 | + # Optional: update just the ligand portion of a shared template |
| 340 | + new_pos = pose_nm.copy() |
| 341 | + new_pos[self.atom_indices] = pose_pos |
| 342 | + new_pos = new_pos * unit.nanometers |
| 343 | + |
| 344 | + if 'boresch' in self.restraints: |
| 345 | + # Only validate force constants, allow atoms to be None |
| 346 | + if not all(x is not None for x in [self.K_r, self.K_angle]): |
| 347 | + raise ValueError("Missing required force constants for Boresch restraints (K_r, K_angle)") |
| 348 | + |
| 349 | + |
| 350 | + new_sys = add_boresch_restraints(sys=new_sys, struct=self.structure, pos=pose_allpos, ligand_atoms=self.atom_indices, |
| 351 | + pose_num=index, force_group=self.restraint_groups['boresch'], |
| 352 | + restrained_receptor_atoms=self.restrained_receptor_atoms, restrained_ligand_atoms=self.restrained_ligand_atoms, |
| 353 | + K_r=self.K_r, K_angle=self.K_angle, K_RMSD=self.K_RMSD, RMSD0=self.RMSD0, |
| 354 | + K_com=self.K_com) |
| 355 | + |
| 356 | + if 'boresch' not in self.restraints and 'rmsd' not in self.restraints: |
| 357 | + raise ValueError(f'Invalid restraint type: {self.restraints}') |
| 358 | + |
| 359 | + return new_sys, new_int |
| 360 | + |
211 | 361 | def getAtomIndices(self, structure, resname): |
212 | 362 | """ |
213 | 363 | Get atom indices of a ligand from ParmEd Structure. |
|
0 commit comments