Skip to content

Commit c543777

Browse files
Merge pull request #426 from macrocosm-os/staging
Staging
2 parents 1d9bdd6 + 12a984c commit c543777

File tree

12 files changed

+1375
-87
lines changed

12 files changed

+1375
-87
lines changed

folding/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "2.4.1"
1+
__version__ = "2.4.2"
22
version_split = __version__.split(".")
33
__spec_version__ = (
44
(10000 * int(version_split[0]))

folding/registries/evaluation_registry.py

+72-23
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import os
2-
from typing import Any, Dict, List, Union
3-
import traceback
2+
import random
3+
from typing import Any, Dict
4+
5+
import bittensor as bt
46

57
import numpy as np
68
import pandas as pd
7-
from openmm import app
8-
import bittensor as bt
9-
import plotly.graph_objects as go
9+
from openmm import app, unit
1010

1111
from folding.base.evaluation import BaseEvaluator
1212
from folding.base.simulation import OpenMMSimulation
@@ -20,6 +20,7 @@
2020
save_files,
2121
save_pdb,
2222
write_pkl,
23+
check_uniqueness,
2324
)
2425
from folding.utils.opemm_simulation_config import SimulationConfig
2526
from folding.protocol import IntermediateSubmissionSynapse
@@ -62,6 +63,7 @@ def __init__(
6263
)
6364

6465
self.intermediate_checkpoint_files = {}
66+
self.miner_reported_energies = {}
6567

6668
def process_md_output(self) -> bool:
6769
"""Method to process molecular dynamics data from a miner and recreate the simulation.
@@ -282,6 +284,21 @@ def compare_state_to_cpt(
282284
return False
283285
return True
284286

287+
def select_stratified_checkpoints(
    self, num_checkpoints: int, num_samples: int
) -> list[int]:
    """Pick checkpoint indices spread evenly across the run.

    NOTE(review): despite its docstring-level intent, this returns
    ``num_samples - 1`` indices, not ``num_samples``: ``np.linspace``
    produces ``num_samples + 1`` edges, the final edge is dropped to
    exclude the last (final) checkpoint, and zipping adjacent edges then
    yields ``num_samples - 1`` bins. The caller in ``is_run_valid``
    already compensates by passing ``MAX_CHECKPOINTS_TO_VALIDATE + 1``
    ("+1 for Final"), so the behavior is deliberately kept unchanged
    here and only documented.

    Args:
        num_checkpoints: Total number of checkpoints available.
        num_samples: Controls the binning; ``num_samples - 1`` indices
            are returned (see note above).

    Returns:
        One uniformly drawn checkpoint index per bin, in ascending bin
        order. Indices fall in ``[0, num_checkpoints)``.
    """
    # N evenly spaced bin edges over [0, num_checkpoints]; the last edge
    # (the final checkpoint) is intentionally excluded.
    edges = np.linspace(0, num_checkpoints, num_samples + 1, dtype=int)[:-1]

    # Draw one index uniformly from each [start, end) bin. The max()
    # guards randint's bounds when a bin is empty (start == end).
    selected = [
        random.randint(start, max(start, end - 1))
        for start, end in zip(edges[:-1], edges[1:])
    ]
    return selected
301+
285302
async def is_run_valid(self, validator=None, job_id=None, axon=None):
286303
"""
287304
Checks if the run is valid by evaluating a set of logical conditions:
@@ -306,7 +323,7 @@ async def is_run_valid(self, validator=None, job_id=None, axon=None):
306323
miner_energies_dict = {}
307324

308325
logger.info(f"Checking if run is valid for {self.hotkey_alias}...")
309-
logger.info(f"Checking final checkpoint...")
326+
logger.info("Checking final checkpoint...")
310327
# Check the final checkpoint
311328
(
312329
is_valid,
@@ -326,11 +343,10 @@ async def is_run_valid(self, validator=None, job_id=None, axon=None):
326343

327344
# Check the intermediate checkpoints
328345
if validator is not None and job_id is not None and axon is not None:
329-
checkpoint_numbers = np.random.choice(
330-
range(self.number_of_checkpoints),
331-
size=c.MAX_CHECKPOINTS_TO_VALIDATE,
332-
replace=False,
333-
).tolist()
346+
checkpoint_numbers = self.select_stratified_checkpoints(
347+
num_checkpoints=self.number_of_checkpoints,
348+
num_samples=c.MAX_CHECKPOINTS_TO_VALIDATE + 1, # +1 for Final
349+
)
334350

335351
# Get intermediate checkpoints from the miner
336352
intermediate_checkpoints = await self.get_intermediate_checkpoints(
@@ -385,11 +401,33 @@ async def is_run_valid(self, validator=None, job_id=None, axon=None):
385401
if not is_valid:
386402
return False, checked_energies_dict, miner_energies_dict, result
387403

404+
# Check if the miner's checkpoint is similar to the validator's checkpoint.
405+
miner_reported_energies = []
406+
checkpoint_length = len(
407+
self.miner_reported_energies[str(checkpoint_numbers[0])]
408+
)
409+
for _, energy in self.miner_reported_energies.items():
410+
miner_reported_energies.append(
411+
energy[:checkpoint_length]
412+
) # final cpt is larger in length.
413+
414+
if not check_uniqueness(
415+
vectors=miner_reported_energies,
416+
tol=c.MINER_CHECKPOINT_SIMILARITY_TOLERANCE,
417+
):
418+
logger.warning("Miner checkpoints not unique")
419+
return (
420+
False,
421+
checked_energies_dict,
422+
miner_energies_dict,
423+
"miner-checkpoint-similarity",
424+
)
425+
388426
return True, checked_energies_dict, miner_energies_dict, "valid"
389427

390-
except ValidationError as E:
391-
logger.warning(f"{E}")
392-
return False, {}, {}, E.message
428+
except ValidationError as e:
429+
logger.warning(f"{e}")
430+
return False, {}, {}, e.message
393431

394432
return True, checked_energies_dict, miner_energies_dict, "valid"
395433

@@ -429,6 +467,7 @@ async def validate(self, validator=None, job_id=None, axon=None):
429467

430468
# Use the final checkpoint's energy for the score
431469
if "final" in checked_energies_dict and checked_energies_dict["final"]:
470+
logger.success(f"Hotkey {self.hotkey_alias} passed validation!")
432471
final_energies = checked_energies_dict["final"]
433472
# Take the median of the last ENERGY_WINDOW_SIZE values
434473
median_energy = np.median(final_energies[-c.ENERGY_WINDOW_SIZE :])
@@ -481,6 +520,17 @@ async def get_intermediate_checkpoints(
481520
def name(self) -> str:
482521
return "SyntheticMD"
483522

523+
def get_miner_log_file_energies(
    self, start_index: int, end_index: int
) -> np.ndarray:
    """Return the potential energies the miner logged for steps in
    the half-open window ``(start_index, end_index]``.

    The left edge is excluded because the checkpoint step itself
    belongs to the previous window.
    """
    step_column = self.log_file['#"Step"']
    in_window = (step_column > start_index) & (step_column <= end_index)
    return self.log_file.loc[in_window, "Potential Energy (kJ/mole)"].values
533+
484534
def is_checkpoint_valid(
485535
self,
486536
checkpoint_path: str,
@@ -544,7 +594,6 @@ def is_checkpoint_valid(
544594

545595
try:
546596
if not self.check_gradient(check_energies=np.array(state_energies)):
547-
logger.warning(f"state energies: {state_energies}")
548597
logger.warning(
549598
f"hotkey {self.hotkey_alias} failed state-gradient check for {self.pdb_id}, checkpoint_num: {checkpoint_num}, ... Skipping!"
550599
)
@@ -589,10 +638,11 @@ def is_checkpoint_valid(
589638

590639
max_step = current_cpt_step + steps_to_run
591640

592-
miner_energies: np.ndarray = self.log_file[
593-
(self.log_file['#"Step"'] > current_cpt_step)
594-
& (self.log_file['#"Step"'] <= max_step)
595-
]["Potential Energy (kJ/mole)"].values
641+
miner_energies: np.ndarray = self.get_miner_log_file_energies(
642+
start_index=current_cpt_step, end_index=max_step
643+
)
644+
645+
self.miner_reported_energies[checkpoint_num] = miner_energies
596646

597647
if len(np.unique(check_energies)) == 1:
598648
logger.warning(
@@ -601,7 +651,6 @@ def is_checkpoint_valid(
601651
raise ValidationError(message="reprod-energies-identical")
602652

603653
if not self.check_gradient(check_energies=np.array(check_energies)):
604-
logger.warning(f"check_energies: {check_energies}")
605654
logger.warning(
606655
f"hotkey {self.hotkey_alias} failed cpt-gradient check for {self.pdb_id}, checkpoint_num: {checkpoint_num}, ... Skipping!"
607656
)
@@ -641,9 +690,9 @@ def is_checkpoint_valid(
641690

642691
return True, check_energies.tolist(), miner_energies.tolist(), "valid"
643692

644-
except ValidationError as E:
645-
logger.warning(f"{E}")
646-
return False, [], [], E.message
693+
except ValidationError as e:
694+
logger.warning(f"{e}")
695+
return False, [], [], e.message
647696

648697

649698
class OrganicMDEvaluator(SyntheticMDEvaluator):

folding/utils/constants.py

+3
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
ENERGY_WINDOW_SIZE = (
1212
10 # Number of steps to compute median/mean energies when comparing
1313
)
14+
MINER_CHECKPOINT_SIMILARITY_TOLERANCE = (
15+
0.05 # Tolerance for cpts to be considered similar. NOT in percent.
16+
)
1417

1518
# MinerRegistry constants
1619
MAX_JOBS_IN_MEMORY = 1000

folding/utils/ops.py

+21
Original file line numberDiff line numberDiff line change
@@ -443,3 +443,24 @@ def write_pdb_file(
443443
positions=positions,
444444
file=open(pdb_location_path, "w"),
445445
)
446+
447+
448+
def are_vectors_too_similar(vec1, vec2, tol=0.01):
    """Return True when two vectors are effectively the same within *tol*.

    Two vectors count as "too similar" when they are exactly equal, or
    agree element-wise within relative tolerance ``tol`` (no absolute
    tolerance component).
    """
    exact_match = np.array_equal(vec1, vec2)
    within_tolerance = np.allclose(vec1, vec2, rtol=tol, atol=0)
    return exact_match or within_tolerance
455+
456+
457+
def check_uniqueness(vectors, tol=0.01):
    """Return True when no pair of vectors is similar within *tol*.

    Compares every unordered pair once; a single too-similar pair makes
    the whole set non-unique.
    """
    arrays = [np.array(vector) for vector in vectors]
    for index, candidate in enumerate(arrays):
        for other in arrays[index + 1 :]:
            if are_vectors_too_similar(candidate, other, tol):
                return False
    return True

folding_api/main.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from folding_api.chain import SubtensorService
1111
from folding_api.protein import router
12+
from folding_api.utility_endpoints import router as utility_router
1213
from folding_api.validator_registry import ValidatorRegistry
1314
from folding_api.auth import APIKeyManager, get_api_key, api_key_router
1415
from folding_api.vars import (
@@ -79,6 +80,6 @@ async def lifespan(app: FastAPI):
7980
# Include routes
8081
app.include_router(router, dependencies=[Depends(get_api_key)])
8182
app.include_router(api_key_router) # API key management routes
82-
83+
app.include_router(utility_router) # Utility endpoints
8384
if __name__ == "__main__":
8485
uvicorn.run("main:app", host="0.0.0.0", port=8029)

folding_api/schemas.py

+78
Original file line numberDiff line numberDiff line change
@@ -176,3 +176,81 @@ class APIKeyCreate(BaseModel):
176176

177177
class APIKeyResponse(APIKeyBase):
178178
key: str
179+
180+
181+
class PDB(BaseModel):
    """A single PDB entry: its identifier and where it came from."""

    pdb_id: str  # PDB identifier (e.g. an RCSB ID) — format not enforced here
    source: str  # origin of the structure — TODO confirm allowed values with callers
184+
185+
186+
class PDBSearchResponse(BaseModel):
    """
    Represents a response from a PDB search.
    """

    matches: List[PDB] = Field(..., description="List of matching PDB IDs")
    # NOTE(review): presumably `total` can exceed len(matches) when results
    # are truncated/paginated — confirm against the search endpoint.
    total: int = Field(..., description="Total number of matches found")
193+
194+
195+
class PDBInfoResponse(BaseModel):
    """
    Represents detailed information about a PDB structure from RCSB.
    """

    pdb_id: str = Field(..., description="PDB ID")
    # All metadata fields below are optional — presumably because RCSB
    # records may lack them; confirm against the fetching endpoint.
    molecule_name: Optional[str] = Field(
        None, description="Name of the molecule/protein"
    )
    classification: Optional[str] = Field(None, description="Structural classification")
    organism: Optional[str] = Field(None, description="Source organism")
    expression_system: Optional[str] = Field(None, description="Expression system used")
207+
208+
209+
class Job(BaseModel):
    """Summary of a single folding job as listed in the job pool."""

    id: str  # NOTE(review): distinct from job_id — confirm which is the canonical key
    type: Literal["organic", "synthetic"]  # mirrors the Organic/Synthetic evaluator split
    job_id: str
    pdb_id: str
    created_at: str  # string-encoded timestamp; exact format not shown here
    status: Literal["active", "inactive", "failed"]
    priority: int  # scheduling priority — ordering semantics not shown here
    validator_hotkey: str  # hotkey of the validator that owns the job
    best_hotkey: str  # presumably the best-performing miner's hotkey — confirm
    s3_links: dict[str, str]  # file label -> S3 URL — TODO confirm key names
220+
221+
222+
class JobPoolResponse(BaseModel):
    """
    Represents a response from a job pool.
    """

    jobs: List[Job] = Field(..., description="List of jobs")
    # NOTE(review): presumably `total` counts all jobs in the pool, not just
    # those returned in this page — confirm against the endpoint.
    total: int = Field(..., description="Total number of jobs")
229+
230+
231+
class Miner(BaseModel):
    """A miner's entry within a job's results."""

    uid: str  # miner UID (string-encoded)
    hotkey: str  # miner hotkey
    energy: dict  # NOTE(review): energy payload schema not shown here — confirm keys
235+
236+
237+
class JobResponse(BaseModel):
    """Full details of a single folding job, including its simulation
    parameters and the participating miners."""

    pdb_id: str
    pdb_file_link: str  # URL to the PDB file
    classification: Optional[str] = Field(None, description="Structural classification")
    expression_system: Optional[str] = Field(None, description="Expression system used")
    mutations: Optional[bool] = Field(None, description="Mutations in the PDB")
    source: str  # where the job's PDB came from — TODO confirm allowed values
    # Simulation parameters. Units are not stated here — presumably those of
    # the OpenMM simulation config; confirm before relying on them.
    temperature: float
    friction: float
    pressure: float
    time_to_live: float
    ff: str  # presumably the force-field name — confirm
    water: str  # presumably the water model — confirm
    box: str  # presumably the simulation box type — confirm
    miners: List[Miner]  # miners attached to this job with their energies
    status: Literal["active", "inactive", "failed"] = Field(
        "inactive", description="Job status"
    )
    created_at: str = Field("", description="Job creation timestamp")
    updated_at: str = Field("", description="Job last update timestamp")

0 commit comments

Comments
 (0)