Skip to content

Commit 149fb0c

Browse files
authored
Merge pull request #409 from macrocosm-os/staging
Staging
2 parents 7d2185c + 827bca6 commit 149fb0c

File tree

14 files changed

+1183
-530
lines changed

14 files changed

+1183
-530
lines changed

folding/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "2.3.4"
1+
__version__ = "2.3.5"
22
version_split = __version__.split(".")
33
__spec_version__ = (
44
(10000 * int(version_split[0]))

folding/miners/folding_miner.py

+16-9
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@
2020
from folding.base.miner import BaseMinerNeuron
2121
from folding.base.simulation import OpenMMSimulation
2222
from folding.protocol import JobSubmissionSynapse
23-
from folding.utils.reporters import ExitFileReporter, LastTwoCheckpointsReporter
23+
from folding.utils.reporters import (
24+
ExitFileReporter,
25+
LastTwoCheckpointsReporter,
26+
ProteinStructureReporter,
27+
)
2428
from folding.utils.ops import (
2529
check_if_directory_exists,
2630
get_tracebacks,
@@ -941,21 +945,24 @@ def configure_commands(
941945
reportInterval=self.CHECKPOINT_INTERVAL,
942946
)
943947
)
944-
simulation.reporters.append(
945-
app.StateDataReporter(
946-
file=f"{self.output_dir}/{state}.log",
947-
reportInterval=self.STATE_DATA_REPORTER_INTERVAL,
948-
step=True,
949-
potentialEnergy=True,
950-
)
951-
)
948+
952949
simulation.reporters.append(
953950
ExitFileReporter(
954951
filename=f"{self.output_dir}/{state}",
955952
reportInterval=self.EXIT_REPORTER_INTERVAL,
956953
file_prefix=state,
957954
)
958955
)
956+
simulation.reporters.append(
957+
ProteinStructureReporter(
958+
file=f"{self.output_dir}/{state}.log",
959+
reportInterval=self.STATE_DATA_REPORTER_INTERVAL,
960+
step=True,
961+
potentialEnergy=True,
962+
reference_pdb=os.path.join(self.output_dir, f"{self.pdb_id}.pdb"),
963+
speed=True,
964+
)
965+
)
959966
state_commands[state] = simulation
960967

961968
return state_commands

folding/organic/organic.py

+85-14
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
from typing import Optional
12
import uuid
23
import time
34
import json
4-
from fastapi import APIRouter, Request, Depends, HTTPException
5+
import tempfile
6+
from fastapi import APIRouter, Request, Depends, HTTPException, UploadFile, File, Form
57
from folding_api.schemas import EpistulaHeaders
68
from folding_api.schemas import FoldingParams
79
from folding.utils.logging import logger
@@ -10,22 +12,14 @@
1012
router = APIRouter()
1113

1214

13-
@router.post("/organic")
14-
async def organic(
15+
def verify_organic_request(
1516
request: Request,
1617
job: FoldingParams,
17-
epistula_headers: EpistulaHeaders = Depends(EpistulaHeaders),
18-
):
18+
epistula_headers: EpistulaHeaders,
19+
) -> None:
1920
"""
20-
This endpoint is used to receive organic requests. Returns success message with the job id.
21-
Args:
22-
request: Request
23-
job: FoldingParams
24-
epistula_headers: EpistulaHeaders
25-
Returns:
26-
dict[str, str]: dict with the job id.
21+
Verify the organic request signature and whitelist.
2722
"""
28-
2923
body_bytes = json.dumps(job.model_dump(), default=str, sort_keys=True).encode(
3024
"utf-8"
3125
)
@@ -37,7 +31,6 @@ async def organic(
3731
raise HTTPException(status_code=403, detail=str(e))
3832

3933
sender_hotkey = epistula_headers.signed_by
40-
4134
if sender_hotkey not in request.app.state.config.organic_whitelist:
4235
logger.warning(
4336
f"Received organic request from {sender_hotkey}, but {sender_hotkey} is not in the whitelist."
@@ -46,9 +39,87 @@ async def organic(
4639
status_code=403, detail="Forbidden, sender not in whitelist."
4740
)
4841

42+
43+
def get_folding_params(query: str = Form(...)) -> FoldingParams:
    """
    Dependency that parses the `query` form field into a FoldingParams.

    Args:
        query: JSON-encoded string submitted as form data.
    Returns:
        FoldingParams: the validated folding parameters.
    Raises:
        HTTPException: 400 when the JSON is malformed or fails validation.
    """
    try:
        query_data = json.loads(query)
        return FoldingParams(**query_data)
    # Narrowed from `Exception`: json.loads raises ValueError
    # (json.JSONDecodeError), FoldingParams(**...) raises pydantic
    # ValidationError (a ValueError subclass) or TypeError for a non-mapping.
    # A broad catch would mask unrelated bugs as client errors.
    except (ValueError, TypeError) as e:
        raise HTTPException(status_code=400, detail=f"Invalid query data: {str(e)}")
52+
53+
54+
@router.post("/organic")
async def organic(
    request: Request,
    job: FoldingParams,
    epistula_headers: EpistulaHeaders = Depends(EpistulaHeaders),
):
    """
    This endpoint is used to receive organic requests for proteins from RCSB or PDBE databases.
    Returns success message with the job id.

    Args:
        request: Request
        job: FoldingParams
        epistula_headers: EpistulaHeaders
    Returns:
        dict[str, str]: dict with the job id.
    """
    # Reject unsigned / non-whitelisted senders before doing any work.
    verify_organic_request(request, job, epistula_headers)

    job_id = str(uuid.uuid4())
    folding_params = job.model_dump()
    folding_params["job_id"] = job_id

    logger.info(f"Received organic request: {folding_params}")
    request.app.state.validator._organic_queue.add(folding_params)

    return {"job_id": job_id}
80+
81+
82+
@router.post("/organic/upload")
async def organic_with_upload(
    request: Request,
    job: FoldingParams = Depends(get_folding_params),
    pdb_file: UploadFile = File(...),
    epistula_headers: EpistulaHeaders = Depends(EpistulaHeaders),
):
    """
    This endpoint is used to receive organic requests with custom PDB files.
    Returns success message with the job id.

    Args:
        request: Request
        job: FoldingParams (parsed from the `query` form field)
        pdb_file: PDB file upload
        epistula_headers: EpistulaHeaders
    Returns:
        dict[str, str]: dict with the job id.
    Raises:
        HTTPException: 403 when the signature/whitelist check fails;
        400 when the uploaded file cannot be read or stored.
    """
    verify_organic_request(request, job, epistula_headers)

    folding_params = job.model_dump()
    folding_params["job_id"] = str(uuid.uuid4())

    # Persist the upload to a temporary file. The try block is kept narrow so
    # that only genuine file-handling failures are reported as an invalid PDB
    # (the original also caught failures of the later param/log statements).
    try:
        content = await pdb_file.read()
        # delete=False is deliberate: the path outlives this request because it
        # is handed to the consumer via the queue. NOTE(review): nothing
        # visible here ever deletes the file — confirm the queue consumer
        # cleans it up, otherwise this leaks one temp file per request.
        with tempfile.NamedTemporaryFile(
            mode="wb", suffix=".pdb", delete=False
        ) as temp_pdb:
            temp_pdb.write(content)
            temp_pdb_path = temp_pdb.name
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid PDB file: {str(e)}")

    # Update folding params with the temporary file path
    folding_params["pdb_file_path"] = temp_pdb_path
    logger.info(f"Created temporary PDB file at: {temp_pdb_path}")

    logger.info(f"Received organic request with PDB file: {folding_params}")
    request.app.state.validator._organic_queue.add(folding_params)

    return {"job_id": folding_params["job_id"]}

folding/store.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def upload_job(
276276
s3_links=event["s3_links"],
277277
priority=event.get("priority", 1),
278278
update_interval=event.get(
279-
"update_interval", random.randint(7200, 14400)
279+
"time_to_live", random.randint(7200, 14400)
280280
), # between 2 hours and 4 hours in seconds
281281
max_time_no_improvement=event.get("max_time_no_improvement", 1),
282282
is_organic=event.get("is_organic", False),

folding/utils/config.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ def add_validator_args(cls, parser):
366366
"--neuron.vpermit_tao_limit",
367367
type=int,
368368
help="The maximum number of TAO allowed to query a validator with a vpermit.",
369-
default=4096,
369+
default=20_000,
370370
)
371371

372372
parser.add_argument(

folding/utils/reporters.py

+139
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
1+
import io
12
import os
3+
import time
24
import openmm.app as app
5+
import MDAnalysis as mda
6+
from MDAnalysis.analysis.rms import rmsd
7+
import numpy as np
8+
9+
310

411

512
class LastTwoCheckpointsReporter(app.CheckpointReporter):
@@ -38,3 +45,135 @@ def report(self, simulation, state):
3845

3946
def finalize(self):
4047
pass
48+
49+
50+
class ProteinStructureReporter(app.StateDataReporter):
    """A StateDataReporter that appends RMSD and RMSF columns to each report.

    RMSD is measured between the simulation's current backbone atoms and a
    fixed reference PDB structure. RMSF is a single aggregate fluctuation
    value computed over the stored history of backbone positions, which is
    capped to avoid unbounded memory growth.
    """

    # Maximum number of backbone-position frames retained for RMSF.
    _MAX_HISTORY = 1000

    def __init__(self, file, reportInterval, reference_pdb, **kwargs):
        """
        Args:
            file: Output file path or file-like object for the report.
            reportInterval: Steps between reports (forwarded to the base class).
            reference_pdb: Path to the PDB file used as the RMSD reference.
            **kwargs: Forwarded to app.StateDataReporter (e.g. step=True).
        """
        super().__init__(file, reportInterval, **kwargs)
        self.reference_universe = mda.Universe(reference_pdb)
        self.positions_history = []  # backbone positions, one entry per report

    def report(self, simulation, state):
        """Generate a report.

        Parameters
        ----------
        simulation : Simulation
            The Simulation to generate a report for
        state : State
            The current state of the simulation
        """
        if not self._hasInitialized:
            # NOTE(review): this mirrors the internals of OpenMM's
            # StateDataReporter.report (_initializeConstants,
            # _constructHeaders, ...); it may need updating if that private
            # API changes between OpenMM versions.
            self._initializeConstants(simulation)
            headers = self._constructHeaders()
            if not self._append:
                print(
                    '#"%s"' % ('"' + self._separator + '"').join(headers),
                    file=self._out,
                )
            try:
                self._out.flush()
            except AttributeError:
                pass
            self._initialClockTime = time.time()
            self._initialSimulationTime = state.getTime()
            self._initialSteps = simulation.currentStep
            self._hasInitialized = True

        # Check for errors.
        self._checkForErrors(simulation, state)

        # Store current backbone positions for RMSF calculation. The history
        # is trimmed here, before it is consumed, so memory stays bounded
        # (the original trimmed only after building the full array inside
        # _calculate_rmsf).
        universe = self.create_mda_universe(simulation)
        self.positions_history.append(
            universe.select_atoms("backbone").positions.copy()
        )
        if len(self.positions_history) > self._MAX_HISTORY:
            self.positions_history = self.positions_history[-self._MAX_HISTORY:]

        # Query for the values.
        values = self._constructReportValues(simulation, state)

        # Write the values.
        print(self._separator.join(str(v) for v in values), file=self._out)
        try:
            self._out.flush()
        except AttributeError:
            pass

    def _constructReportValues(self, simulation, state):
        """Extend the base report values with RMSD and RMSF columns."""
        values = super()._constructReportValues(simulation, state)
        # Renamed from `rmsd` so the local does not shadow the imported
        # MDAnalysis `rmsd` function of the same name.
        rmsd_value = self._calculate_rmsd(self.create_mda_universe(simulation))
        rmsf_value = self._calculate_rmsf()
        values.extend([rmsd_value, rmsf_value])
        return values

    def _calculate_rmsd(self, universe):
        """Calculate RMSD between current and reference backbone positions.

        Args:
            universe (mda.Universe): Universe holding the current positions.

        Returns:
            float: RMSD between current and reference backbone atoms
            (both structures are centered before comparison).
        """
        current_positions = universe.select_atoms("backbone").positions.copy()
        reference_positions = self.reference_universe.select_atoms(
            "backbone"
        ).positions.copy()
        return rmsd(current_positions, reference_positions, center=True)

    def _calculate_rmsf(self):
        """Calculate an aggregate RMSF over the stored position history.

        Note: this is a single scalar — the square root of the mean squared
        deviation pooled over frames, atoms, and x/y/z — not the conventional
        per-atom RMSF profile.

        Returns:
            float: aggregate RMSF (0.0 until at least two frames are stored).
        """
        if len(self.positions_history) < 2:
            return 0.0

        # Stack history into (n_frames, n_atoms, 3) and measure fluctuation
        # around the per-atom mean position.
        positions_array = np.array(self.positions_history)
        mean_positions = np.mean(positions_array, axis=0)
        squared_diff = np.square(positions_array - mean_positions)
        return np.sqrt(np.mean(squared_diff))

    def _constructHeaders(self):
        """Extend the base headers with the RMSD/RMSF column names."""
        headers = super()._constructHeaders()
        headers.extend(["RMSD", "RMSF"])
        return headers

    def create_mda_universe(self, simulation):
        """
        Create an MDAnalysis Universe from an OpenMM simulation object.

        Args:
            simulation (openmm.app.Simulation): The OpenMM simulation object

        Returns:
            mda.Universe: An MDAnalysis Universe containing the current state
            of the simulation
        """
        # Get the current state and topology.
        state = simulation.context.getState(getPositions=True)
        positions = state.getPositions(asNumpy=True)
        topology = simulation.topology

        # Round-trip through an in-memory PDB so MDAnalysis can parse it.
        pdb_string = io.StringIO()
        app.PDBFile.writeFile(topology, positions, pdb_string)
        pdb_string.seek(0)

        # NOTE(review): some MDAnalysis versions require wrapping streams in
        # mda.lib.util.NamedStream — confirm a bare StringIO works with the
        # pinned MDAnalysis release.
        universe = mda.Universe(pdb_string, format="pdb")

        return universe
179+

folding/utils/uids.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ def get_random_uids(self, k: int, exclude: List[int] = None) -> torch.LongTensor
6565
return uids
6666

6767

68-
def get_all_miner_uids(metagraph, vpermit_tao_limit, include_serving_in_check: bool = True) -> List[int]:
68+
def get_all_miner_uids(
69+
metagraph, vpermit_tao_limit, include_serving_in_check: bool = True
70+
) -> List[int]:
6971
"""Returns all available miner uids from the metagraph.
7072
Returns:
7173
uids (List): All available miner uids.
@@ -74,10 +76,10 @@ def get_all_miner_uids(metagraph, vpermit_tao_limit, include_serving_in_check: b
7476
candidate_uids = []
7577
for uid in range(metagraph.n.item()):
7678
uid_is_available = check_uid_availability(
77-
metagraph = metagraph,
78-
uid = uid,
79-
vpermit_tao_limit = vpermit_tao_limit,
80-
include_serving_in_check = include_serving_in_check,
79+
metagraph=metagraph,
80+
uid=uid,
81+
vpermit_tao_limit=vpermit_tao_limit,
82+
include_serving_in_check=include_serving_in_check,
8183
)
8284
if uid_is_available:
8385
candidate_uids.append(uid)

0 commit comments

Comments
 (0)