|
1 |
| -import os |
2 |
| -import time |
3 |
| -import glob |
| 1 | +import asyncio |
4 | 2 | import base64
|
| 3 | +import glob |
| 4 | +import os |
5 | 5 | import random
|
6 | 6 | import shutil
|
7 |
| -import asyncio |
8 |
| -import datetime |
9 |
| -from pathlib import Path |
10 |
| -from dataclasses import dataclass |
| 7 | +import time |
11 | 8 | from collections import defaultdict
|
12 |
| -from typing import Dict, List, Literal, Any |
| 9 | +from dataclasses import dataclass |
| 10 | +from pathlib import Path |
| 11 | +from typing import Any, Dict, List, Literal |
13 | 12 |
|
14 | 13 | import numpy as np
|
15 | 14 | import pandas as pd
|
16 | 15 | from openmm import app, unit
|
17 | 16 | from pdbfixer import PDBFixer
|
18 | 17 |
|
19 |
| -from folding.utils.s3_utils import DigitalOceanS3Handler |
20 | 18 | from folding.base.simulation import OpenMMSimulation
|
21 | 19 | from folding.store import Job
|
| 20 | +from folding.utils.logger import logger |
22 | 21 | from folding.utils.opemm_simulation_config import SimulationConfig
|
23 | 22 | from folding.utils.ops import (
|
24 | 23 | OpenMMException,
|
25 | 24 | ValidationError,
|
26 |
| - write_pkl, |
27 |
| - load_pkl, |
28 | 25 | check_and_download_pdbs,
|
29 | 26 | check_if_directory_exists,
|
| 27 | + load_pkl, |
30 | 28 | plot_miner_validator_curves,
|
| 29 | + write_pkl, |
31 | 30 | )
|
32 |
| - |
33 |
| -from folding.utils.logger import logger |
| 31 | +from folding.utils.s3_utils import DigitalOceanS3Handler |
34 | 32 |
|
35 | 33 | ROOT_DIR = Path(__file__).resolve().parents[2]
|
36 | 34 |
|
@@ -537,6 +535,28 @@ def check_masses(self) -> bool:
|
537 | 535 | logger.error(f"Masses for atom {i} do not match. Validator: {v_mass}, Miner: {m_mass}")
|
538 | 536 | return False
|
539 | 537 | return True
|
| 538 | + |
| 539 | + def compare_state_to_cpt(self, state_energies: list, checkpoint_energies: list) -> bool: |
| 540 | + """ |
| 541 | + Check if the state file is the same as the checkpoint file by comparing the median of the first few energy values |
| 542 | + in the simulation created by the checkpoint and the state file respectively. |
| 543 | + """ |
| 544 | + |
| 545 | + WINDOW = 50 |
| 546 | + |
| 547 | + state_energies = np.array(state_energies) |
| 548 | + checkpoint_energies = np.array(checkpoint_energies) |
| 549 | + |
| 550 | + state_median = np.median(state_energies[:WINDOW]) |
| 551 | + checkpoint_median = np.median(checkpoint_energies[:WINDOW]) |
| 552 | + |
| 553 | + percent_diff = abs((state_median - checkpoint_median) / checkpoint_median) * 100 |
| 554 | + |
| 555 | + if percent_diff > self.epsilon: |
| 556 | + return False |
| 557 | + return True |
| 558 | + |
| 559 | + |
540 | 560 |
|
541 | 561 | def is_run_valid(self):
|
542 | 562 | """
|
@@ -575,8 +595,8 @@ def is_run_valid(self):
|
575 | 595 | )
|
576 | 596 | self.simulation.loadState(self.state_xml_path)
|
577 | 597 | state_energies = []
|
578 |
| - for _ in range(100): |
579 |
| - self.simulation.step(100) |
| 598 | + for _ in range(steps_to_run // 10): |
| 599 | + self.simulation.step(10) |
580 | 600 | energy = self.simulation.context.getState(getEnergy=True).getPotentialEnergy()._value
|
581 | 601 | state_energies.append(energy)
|
582 | 602 |
|
@@ -622,6 +642,10 @@ def is_run_valid(self):
|
622 | 642 | if not self.check_gradient(check_energies=check_energies):
|
623 | 643 | logger.warning(f"hotkey {self.hotkey_alias} failed cpt-gradient check for {self.pdb_id}, ... Skipping!")
|
624 | 644 | return False, [], [], "cpt-gradient"
|
| 645 | + |
| 646 | + if not self.compare_state_to_cpt(state_energies=state_energies, checkpoint_energies=check_energies): |
| 647 | + logger.warning(f"hotkey {self.hotkey_alias} failed state-checkpoint comparison for {self.pdb_id}, ... Skipping!") |
| 648 | + return False, [], [], "state-checkpoint" |
625 | 649 |
|
626 | 650 | # calculating absolute percent difference per step
|
627 | 651 | percent_diff = abs(((check_energies - miner_energies) / miner_energies) * 100)
|
|
0 commit comments