Skip to content

Commit 6113633

Browse files
authored
Merge pull request #443 from macrocosm-os/staging
Staging
2 parents 5e3846f + 84c3c28 commit 6113633

File tree

8 files changed

+290
-45
lines changed

8 files changed

+290
-45
lines changed

folding/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "2.5.2"
1+
__version__ = "2.6.0"
22
version_split = __version__.split(".")
33
__spec_version__ = (
44
(10000 * int(version_split[0]))

folding/miners/folding_miner.py

+14-13
Original file line numberDiff line numberDiff line change
@@ -58,23 +58,25 @@ def attach_files(
5858
return synapse
5959

6060

61-
async def upload_to_s3(session: aiohttp.ClientSession, presigned_url: dict, file_path: str) -> None:
61+
async def upload_to_s3(presigned_url: dict, file_path: str) -> None:
6262
"""Asynchronously upload a file to S3 using presigned URL"""
6363
try:
6464
start_time = time.time()
6565
data = FormData()
6666
for key, value in presigned_url["fields"].items():
6767
data.add_field(key, value)
68-
68+
6969
with open(file_path, "rb") as f:
7070
data.add_field("file", f, filename="trajectory.dcd")
71-
72-
async with session.post(
73-
presigned_url["url"],
74-
data=data
75-
) as response:
76-
if response.status != 204:
77-
logger.error(f"Failed to upload trajectory to s3: {await response.text()}")
71+
72+
async with aiohttp.ClientSession() as session:
73+
async with session.post(
74+
presigned_url["url"],
75+
data=data
76+
) as response:
77+
if response.status != 204:
78+
logger.error(f"Failed to upload trajectory to s3: {await response.text()}")
79+
7880
except Exception as e:
7981
logger.error(f"Error uploading to S3: {e}")
8082
get_tracebacks()
@@ -114,7 +116,6 @@ def attach_files_to_synapse(
114116
trajectory_path = os.path.join(data_directory, "trajectory.dcd")
115117
if os.path.exists(trajectory_path):
116118
asyncio.create_task(upload_to_s3(
117-
session=aiohttp.ClientSession(),
118119
presigned_url=synapse.presigned_url,
119120
file_path=trajectory_path
120121
))
@@ -921,9 +922,9 @@ def __init__(
921922
}
922923

923924
self.STATES = ["nvt", "npt", "md_0_1"]
924-
self.CHECKPOINT_INTERVAL = 10000
925-
self.TRAJECTORY_INTERVAL = 10000
926-
self.STATE_DATA_REPORTER_INTERVAL = 10
925+
self.CHECKPOINT_INTERVAL = self.system_config.save_interval_checkpoint
926+
self.TRAJECTORY_INTERVAL = self.system_config.save_interval_trajectory
927+
self.STATE_DATA_REPORTER_INTERVAL = self.system_config.save_interval_log
927928
self.EXIT_REPORTER_INTERVAL = 10
928929

929930
def create_empty_file(self, file_path: str):

folding/organic/api.py

+97
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
from folding.base import validator
44
from folding.organic.organic import router as organic_router
55
from folding.utils.logging import logger
6+
import multiprocessing
7+
from multiprocessing.connection import Connection
8+
import pickle
9+
from typing import Optional
610

711
app = FastAPI()
812

@@ -25,3 +29,96 @@ async def start_organic_api(organic_validator, config):
2529
)
2630
server = uvicorn.Server(config)
2731
await server.serve()
32+
33+
34+
def api_process_main(pipe_connection: Connection, config):
    """
    Entry point executed inside the dedicated organic-API process.

    Serves the FastAPI ``app`` with uvicorn and, roughly once per second,
    drains a process-local ``OrganicQueue`` and forwards the collected jobs
    to the main process over ``pipe_connection``.

    Args:
        pipe_connection: Pipe end used to send job batches to the main process
        config: Configuration for the API; ``config.neuron.organic_api.port``
            is the port uvicorn listens on
    """
    from folding.organic.api import app
    import uvicorn
    from atom.organic_scoring.organic_queue import OrganicQueue
    import asyncio
    from asyncio import Task

    class PipeOrganicValidator:
        """Stand-in validator that relays queued jobs through the pipe."""

        def __init__(self, pipe_connection):
            self._organic_queue = OrganicQueue()
            self._pipe_connection = pipe_connection
            self._check_queue_task: Optional[Task] = None

        async def check_queue(self):
            """Poll the queue forever, shipping every batch found to the main process"""
            while True:
                try:
                    queue = self._organic_queue
                    if not queue.is_empty():
                        # Drain the queue completely; falsy samples are dropped
                        pending = []
                        while not queue.is_empty():
                            job = queue.sample()
                            if job:
                                pending.append(job)

                        if pending:
                            logger.info(f"Sending {len(pending)} jobs to main process")
                            # The batch is sent as an explicit pickle byte string,
                            # so the receiving side presumably calls pickle.loads
                            # on what it reads -- verify against the reader.
                            payload = pickle.dumps(pending)
                            self._pipe_connection.send(payload)
                except Exception as e:
                    # Best effort: log and keep polling rather than kill the task
                    logger.error(f"Error checking queue: {e}")

                # Poll more frequently than the main process reads
                await asyncio.sleep(1)

    # This process gets its own event loop, shared by the poller and uvicorn
    event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(event_loop)

    # Build the pipe-backed validator and schedule its polling task
    organic_validator = PipeOrganicValidator(pipe_connection)
    organic_validator._check_queue_task = event_loop.create_task(
        organic_validator.check_queue()
    )

    # Expose the validator and config on the app state
    app.state.validator = organic_validator
    app.state.config = config

    # Configure uvicorn and serve until the server exits
    server = uvicorn.Server(
        uvicorn.Config(
            "folding.organic.api:app",
            host="0.0.0.0",
            port=config.neuron.organic_api.port,
            loop="asyncio",
            reload=False,
        )
    )
    event_loop.run_until_complete(server.serve())
105+
106+
107+
def start_organic_api_in_process(config):
    """
    Start the organic API in a separate daemon process.

    A ``multiprocessing.Pipe`` is created; the child end is handed to
    ``api_process_main`` together with ``config``, and the parent end is
    returned so the caller can receive jobs from the API process.

    Args:
        config: Configuration for the API

    Returns:
        tuple: ``(parent_conn, process)`` -- the parent end of the pipe on
        which jobs from the API process arrive, and the started
        ``multiprocessing.Process`` running the API.
    """
    parent_conn, child_conn = multiprocessing.Pipe()
    process = multiprocessing.Process(
        target=api_process_main, args=(child_conn, config), daemon=True
    )
    process.start()
    logger.info(f"Started organic API in separate process (PID: {process.pid})")
    return parent_conn, process

folding/registries/evaluation_registry.py

+94-10
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
import numpy as np
88
import pandas as pd
99
from openmm import app
10-
10+
import MDAnalysis as mda
11+
from MDAnalysis.analysis.rms import rmsd
1112
from folding.base.evaluation import BaseEvaluator
1213
from folding.base.simulation import OpenMMSimulation
1314
from folding.utils import constants as c
@@ -106,10 +107,24 @@ def process_md_output(self) -> bool:
106107
)
107108

108109
try:
110+
111+
self.trajectory_path = os.path.join(
112+
self.miner_data_directory, "trajectory.dcd"
113+
)
114+
115+
# download the trajectory from s3
116+
self.s3_handler.get(self.trajectory_s3_path, self.trajectory_path)
117+
118+
# check if file exists
119+
if not os.path.exists(self.trajectory_path):
120+
logger.error(
121+
f"Trajectory file {self.trajectory_path} does not exist... Skipping!"
122+
)
123+
return False
124+
109125
# NOTE: The seed written in the self.system_config is not used here
110126
# because the miner could have used something different and we want to
111127
# make sure that we are using the correct seed.
112-
113128
logger.info(
114129
f"Recreating miner {self.hotkey_alias} simulation in state: {self.current_state}"
115130
)
@@ -129,12 +144,6 @@ def process_md_output(self) -> bool:
129144
self.log_file_path = os.path.join(
130145
self.miner_data_directory, self.md_outputs_exts["log"]
131146
)
132-
self.trajectory_path = os.path.join(
133-
self.miner_data_directory, "trajectory.dcd"
134-
)
135-
136-
# download the trajectory from s3
137-
self.s3_handler.get(self.trajectory_s3_path, self.trajectory_path)
138147

139148
simulation.loadCheckpoint(checkpoint_path)
140149

@@ -210,7 +219,7 @@ def process_md_output(self) -> bool:
210219
f"Miner {self.hotkey_alias} has modified the system in unintended ways... Skipping!"
211220
)
212221
self.number_of_checkpoints = (
213-
int(self.log_file['#"Step"'].iloc[-1] / 10000) - 1
222+
int(self.cpt_step // self.system_config.save_interval_checkpoint) - 1
214223
)
215224
if self.number_of_checkpoints < c.MAX_CHECKPOINTS_TO_VALIDATE:
216225
raise ValidationError(
@@ -338,6 +347,7 @@ async def is_run_valid(
338347
result,
339348
) = self.is_checkpoint_valid(
340349
checkpoint_path=self.checkpoint_path,
350+
current_cpt_step=self.cpt_step,
341351
steps_to_run=c.MAX_SIMULATION_STEPS_FOR_EVALUATION,
342352
checkpoint_num="final",
343353
)
@@ -395,13 +405,18 @@ async def is_run_valid(
395405
with open(temp_checkpoint_path, "wb") as f:
396406
f.write(checkpoint_data)
397407

408+
cpt_step = (
409+
int(checkpoint_num) + 1
410+
) * self.system_config.save_interval_checkpoint
411+
398412
(
399413
is_valid,
400414
checked_energies,
401415
miner_energies,
402416
result,
403417
) = self.is_checkpoint_valid(
404418
checkpoint_path=temp_checkpoint_path,
419+
current_cpt_step=cpt_step,
405420
steps_to_run=c.INTERMEDIATE_CHECKPOINT_STEPS,
406421
checkpoint_num=checkpoint_num,
407422
)
@@ -544,9 +559,45 @@ def get_miner_log_file_energies(
544559

545560
return miner_energies
546561

562+
def calculate_rmsd(
    self, miner_trajectory, validator_trajectory, start_frame, end_frame
):
    """Compute the backbone RMSD between paired miner and validator frames.

    Frames ``start_frame:end_frame`` of the miner trajectory are paired, in
    order, with the validator trajectory's frames starting from its first
    frame. Pairing stops as soon as either side runs out of frames.

    Args:
        miner_trajectory: MDAnalysis Universe for miner trajectory
        validator_trajectory: MDAnalysis Universe for validator trajectory
        start_frame: Index of the first miner frame to compare
        end_frame: Index one past the last miner frame to compare

    Returns:
        list[float]: One RMSD value per paired frame. The list is empty when
        no frames pair up, so callers that aggregate it should handle that case.
    """
    # Backbone selections for both universes
    miner_backbone = miner_trajectory.select_atoms("backbone")
    validator_backbone = validator_trajectory.select_atoms("backbone")

    paired_frames = zip(
        miner_trajectory.trajectory[start_frame:end_frame],
        validator_trajectory.trajectory,
    )

    # The timestep objects themselves are unused: stepping both trajectory
    # iterators is what moves each selection's ``positions`` to the next frame
    # (MDAnalysis behaviour), so positions are read fresh on every iteration.
    return [
        rmsd(
            miner_backbone.positions.copy(),
            validator_backbone.positions.copy(),
            center=True,
            superposition=True,
        )
        for _ in paired_frames
    ]
596+
547597
def is_checkpoint_valid(
548598
self,
549599
checkpoint_path: str,
600+
current_cpt_step: int,
550601
steps_to_run: int = c.MIN_SIMULATION_STEPS,
551602
checkpoint_num: str = "final",
552603
):
@@ -580,7 +631,6 @@ def is_checkpoint_valid(
580631

581632
# Load checkpoint
582633
simulation.loadCheckpoint(checkpoint_path)
583-
current_cpt_step = simulation.currentStep
584634

585635
if current_cpt_step + steps_to_run > self.log_step:
586636
raise ValidationError(message="simulation-step-out-of-range")
@@ -620,6 +670,10 @@ def is_checkpoint_valid(
620670
self.miner_data_directory, f"check_{checkpoint_num}.log"
621671
)
622672

673+
current_state_trajectory = os.path.join(
674+
self.miner_data_directory, f"check_{checkpoint_num}.dcd"
675+
)
676+
623677
simulation, _ = self.md_simulator.create_simulation(
624678
pdb=load_pdb_file(pdb_file=self.pdb_location),
625679
system_config=self.system_config.get_config(),
@@ -639,6 +693,13 @@ def is_checkpoint_valid(
639693
)
640694
)
641695

696+
simulation.reporters.append(
697+
app.DCDReporter(
698+
current_state_trajectory,
699+
self.system_config.save_interval_trajectory,
700+
)
701+
)
702+
642703
logger.info(
643704
f"Running {steps_to_run} steps. log_step: {self.log_step}, cpt_step: {current_cpt_step}"
644705
)
@@ -695,6 +756,29 @@ def is_checkpoint_valid(
695756
)
696757
raise ValidationError(message="anomaly")
697758

759+
# convert steps to frames
760+
start_frame = (
761+
current_cpt_step // self.system_config.save_interval_trajectory
762+
)
763+
end_frame = max_step // self.system_config.save_interval_trajectory
764+
765+
check_universe = mda.Universe(self.pdb_location, current_state_trajectory)
766+
miner_universe = mda.Universe(self.pdb_location, self.trajectory_path)
767+
768+
# Calculate RMSD between the trajectories for each frame
769+
rmsds = self.calculate_rmsd(
770+
miner_universe, check_universe, start_frame, end_frame
771+
)
772+
773+
# Get median RMSD value
774+
median_rmsd = np.median(rmsds)
775+
776+
if median_rmsd > 1:
777+
logger.warning(
778+
f"hotkey {self.hotkey_alias} failed trajectory RMSD check for {self.pdb_id}, checkpoint_num: {checkpoint_num}, with median RMSD: {median_rmsd} ... Skipping!"
779+
)
780+
raise ValidationError(message="trajectory-rmsd")
781+
698782
# Save the intermediate or final pdb file if the run is valid
699783
positions = simulation.context.getState(getPositions=True).getPositions()
700784
topology = simulation.topology

folding/utils/opemm_simulation_config.py

+1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class SimulationConfig(BaseModel):
2626
time_step_size: float = 0.002
2727
time_units: str = "picosecond"
2828
save_interval_checkpoint: int = 10000
29+
save_interval_trajectory: int = 100
2930
save_interval_log: int = 10
3031
box_padding: float = 1.0
3132
friction: float = 1.0

0 commit comments

Comments
 (0)