7
7
import numpy as np
8
8
import pandas as pd
9
9
from openmm import app
10
-
10
+ import MDAnalysis as mda
11
+ from MDAnalysis .analysis .rms import rmsd
11
12
from folding .base .evaluation import BaseEvaluator
12
13
from folding .base .simulation import OpenMMSimulation
13
14
from folding .utils import constants as c
@@ -106,10 +107,24 @@ def process_md_output(self) -> bool:
106
107
)
107
108
108
109
try :
110
+
111
+ self .trajectory_path = os .path .join (
112
+ self .miner_data_directory , "trajectory.dcd"
113
+ )
114
+
115
+ # download the trajectory from s3
116
+ self .s3_handler .get (self .trajectory_s3_path , self .trajectory_path )
117
+
118
+ # check if file exists
119
+ if not os .path .exists (self .trajectory_path ):
120
+ logger .error (
121
+ f"Trajectory file { self .trajectory_path } does not exist... Skipping!"
122
+ )
123
+ return False
124
+
109
125
# NOTE: The seed written in the self.system_config is not used here
110
126
# because the miner could have used something different and we want to
111
127
# make sure that we are using the correct seed.
112
-
113
128
logger .info (
114
129
f"Recreating miner { self .hotkey_alias } simulation in state: { self .current_state } "
115
130
)
@@ -129,12 +144,6 @@ def process_md_output(self) -> bool:
129
144
self .log_file_path = os .path .join (
130
145
self .miner_data_directory , self .md_outputs_exts ["log" ]
131
146
)
132
- self .trajectory_path = os .path .join (
133
- self .miner_data_directory , "trajectory.dcd"
134
- )
135
-
136
- # download the trajectory from s3
137
- self .s3_handler .get (self .trajectory_s3_path , self .trajectory_path )
138
147
139
148
simulation .loadCheckpoint (checkpoint_path )
140
149
@@ -210,7 +219,7 @@ def process_md_output(self) -> bool:
210
219
f"Miner { self .hotkey_alias } has modified the system in unintended ways... Skipping!"
211
220
)
212
221
self .number_of_checkpoints = (
213
- int (self .log_file [ '#"Step"' ]. iloc [ - 1 ] / 10000 ) - 1
222
+ int (self .cpt_step // self . system_config . save_interval_checkpoint ) - 1
214
223
)
215
224
if self .number_of_checkpoints < c .MAX_CHECKPOINTS_TO_VALIDATE :
216
225
raise ValidationError (
@@ -338,6 +347,7 @@ async def is_run_valid(
338
347
result ,
339
348
) = self .is_checkpoint_valid (
340
349
checkpoint_path = self .checkpoint_path ,
350
+ current_cpt_step = self .cpt_step ,
341
351
steps_to_run = c .MAX_SIMULATION_STEPS_FOR_EVALUATION ,
342
352
checkpoint_num = "final" ,
343
353
)
@@ -395,13 +405,18 @@ async def is_run_valid(
395
405
with open (temp_checkpoint_path , "wb" ) as f :
396
406
f .write (checkpoint_data )
397
407
408
+ cpt_step = (
409
+ int (checkpoint_num ) + 1
410
+ ) * self .system_config .save_interval_checkpoint
411
+
398
412
(
399
413
is_valid ,
400
414
checked_energies ,
401
415
miner_energies ,
402
416
result ,
403
417
) = self .is_checkpoint_valid (
404
418
checkpoint_path = temp_checkpoint_path ,
419
+ current_cpt_step = cpt_step ,
405
420
steps_to_run = c .INTERMEDIATE_CHECKPOINT_STEPS ,
406
421
checkpoint_num = checkpoint_num ,
407
422
)
@@ -544,9 +559,45 @@ def get_miner_log_file_energies(
544
559
545
560
return miner_energies
546
561
562
def calculate_rmsd(
    self, miner_trajectory, validator_trajectory, start_frame, end_frame
):
    """Compute the per-frame backbone RMSD between miner and validator runs.

    The miner trajectory is restricted to the window
    ``[start_frame:end_frame]`` and walked in lockstep with the validator
    trajectory, which is read from its first frame. Pairing stops as soon
    as either side runs out of frames.

    Args:
        miner_trajectory: MDAnalysis Universe for miner trajectory
        validator_trajectory: MDAnalysis Universe for validator trajectory
        start_frame: Index of the first miner frame to compare
        end_frame: Index one past the last miner frame to compare

    Returns:
        list[float]: One RMSD value per compared frame pair. The list is
            empty when the window or the validator trajectory yields no
            frames, so callers should be prepared for that case.
    """
    # Backbone selections are made once up front and re-read on each frame.
    backbone_of_miner = miner_trajectory.select_atoms("backbone")
    backbone_of_validator = validator_trajectory.select_atoms("backbone")

    miner_window = miner_trajectory.trajectory[start_frame:end_frame]
    paired_frames = zip(miner_window, validator_trajectory.trajectory)

    # The timestep objects themselves are not needed: stepping the paired
    # iterators is what moves both universes forward, and the positions are
    # then read from the atom groups (copied so each frame's coordinates are
    # independent of later frames).
    return [
        rmsd(
            backbone_of_miner.positions.copy(),
            backbone_of_validator.positions.copy(),
            center=True,
            superposition=True,
        )
        for _ in paired_frames
    ]
596
+
547
597
def is_checkpoint_valid (
548
598
self ,
549
599
checkpoint_path : str ,
600
+ current_cpt_step : int ,
550
601
steps_to_run : int = c .MIN_SIMULATION_STEPS ,
551
602
checkpoint_num : str = "final" ,
552
603
):
@@ -580,7 +631,6 @@ def is_checkpoint_valid(
580
631
581
632
# Load checkpoint
582
633
simulation .loadCheckpoint (checkpoint_path )
583
- current_cpt_step = simulation .currentStep
584
634
585
635
if current_cpt_step + steps_to_run > self .log_step :
586
636
raise ValidationError (message = "simulation-step-out-of-range" )
@@ -620,6 +670,10 @@ def is_checkpoint_valid(
620
670
self .miner_data_directory , f"check_{ checkpoint_num } .log"
621
671
)
622
672
673
+ current_state_trajectory = os .path .join (
674
+ self .miner_data_directory , f"check_{ checkpoint_num } .dcd"
675
+ )
676
+
623
677
simulation , _ = self .md_simulator .create_simulation (
624
678
pdb = load_pdb_file (pdb_file = self .pdb_location ),
625
679
system_config = self .system_config .get_config (),
@@ -639,6 +693,13 @@ def is_checkpoint_valid(
639
693
)
640
694
)
641
695
696
+ simulation .reporters .append (
697
+ app .DCDReporter (
698
+ current_state_trajectory ,
699
+ self .system_config .save_interval_trajectory ,
700
+ )
701
+ )
702
+
642
703
logger .info (
643
704
f"Running { steps_to_run } steps. log_step: { self .log_step } , cpt_step: { current_cpt_step } "
644
705
)
@@ -695,6 +756,29 @@ def is_checkpoint_valid(
695
756
)
696
757
raise ValidationError (message = "anomaly" )
697
758
759
+ # convert steps to frames
760
+ start_frame = (
761
+ current_cpt_step // self .system_config .save_interval_trajectory
762
+ )
763
+ end_frame = max_step // self .system_config .save_interval_trajectory
764
+
765
+ check_universe = mda .Universe (self .pdb_location , current_state_trajectory )
766
+ miner_universe = mda .Universe (self .pdb_location , self .trajectory_path )
767
+
768
+ # Calculate RMSD between the trajectories for each frame
769
+ rmsds = self .calculate_rmsd (
770
+ miner_universe , check_universe , start_frame , end_frame
771
+ )
772
+
773
+ # Get median RMSD value
774
+ median_rmsd = np .median (rmsds )
775
+
776
+ if median_rmsd > 1 :
777
+ logger .warning (
778
+ f"hotkey { self .hotkey_alias } failed trajectory RMSD check for { self .pdb_id } , checkpoint_num: { checkpoint_num } , with median RMSD: { median_rmsd } ... Skipping!"
779
+ )
780
+ raise ValidationError (message = "trajectory-rmsd" )
781
+
698
782
# Save the intermediate or final pdb file if the run is valid
699
783
positions = simulation .context .getState (getPositions = True ).getPositions ()
700
784
topology = simulation .topology
0 commit comments