1
1
import os
2
- from typing import Any , Dict , List , Union
3
- import traceback
2
+ import random
3
+ from typing import Any , Dict
4
+
5
+ import bittensor as bt
4
6
5
7
import numpy as np
6
8
import pandas as pd
7
- from openmm import app
8
- import bittensor as bt
9
- import plotly .graph_objects as go
9
+ from openmm import app , unit
10
10
11
11
from folding .base .evaluation import BaseEvaluator
12
12
from folding .base .simulation import OpenMMSimulation
20
20
save_files ,
21
21
save_pdb ,
22
22
write_pkl ,
23
+ check_uniqueness ,
23
24
)
24
25
from folding .utils .opemm_simulation_config import SimulationConfig
25
26
from folding .protocol import IntermediateSubmissionSynapse
@@ -62,6 +63,7 @@ def __init__(
62
63
)
63
64
64
65
self .intermediate_checkpoint_files = {}
66
+ self .miner_reported_energies = {}
65
67
66
68
def process_md_output (self ) -> bool :
67
69
"""Method to process molecular dynamics data from a miner and recreate the simulation.
@@ -282,6 +284,21 @@ def compare_state_to_cpt(
282
284
return False
283
285
return True
284
286
287
def select_stratified_checkpoints(
    self, num_checkpoints: int, num_samples: int
) -> list[int]:
    """Pick one random checkpoint index from each stratum except the last.

    The range ``[0, num_checkpoints)`` is cut into ``num_samples`` contiguous
    bins using integer-truncated, evenly spaced edges. One index is drawn
    uniformly at random from every bin except the last one, so at most
    ``num_samples - 1`` indices are returned, in ascending order.
    The last bin is skipped on purpose (the caller passes one extra sample
    "for Final" — presumably because the final checkpoint is validated
    separately; confirm against the caller).

    Args:
        num_checkpoints: Number of available checkpoints; valid indices are
            ``0 .. num_checkpoints - 1``.
        num_samples: Number of bins to split the checkpoints into.

    Returns:
        A list of distinct checkpoint indices as plain Python ints. Bins that
        end up empty (which happens when ``num_checkpoints`` is small relative
        to ``num_samples``) contribute nothing, so the list may be shorter
        than ``num_samples - 1`` and is empty when there are no checkpoints.
    """
    # Left edge of every bin; the right-most edge (== num_checkpoints) is
    # dropped, which is what removes the last bin from the pairs below.
    edges = np.linspace(
        0, num_checkpoints, num_samples + 1, dtype=int
    )[:-1].tolist()

    # Each (start, end) pair is the half-open bin [start, end). Integer
    # truncation can make a bin empty (start == end); sampling from it would
    # only repeat an index owned by a neighbouring bin, so skip it.
    return [
        random.randrange(start, end)
        for start, end in zip(edges, edges[1:])
        if start < end
    ]
301
+
285
302
async def is_run_valid (self , validator = None , job_id = None , axon = None ):
286
303
"""
287
304
Checks if the run is valid by evaluating a set of logical conditions:
@@ -306,7 +323,7 @@ async def is_run_valid(self, validator=None, job_id=None, axon=None):
306
323
miner_energies_dict = {}
307
324
308
325
logger .info (f"Checking if run is valid for { self .hotkey_alias } ..." )
309
- logger .info (f "Checking final checkpoint..." )
326
+ logger .info ("Checking final checkpoint..." )
310
327
# Check the final checkpoint
311
328
(
312
329
is_valid ,
@@ -326,11 +343,10 @@ async def is_run_valid(self, validator=None, job_id=None, axon=None):
326
343
327
344
# Check the intermediate checkpoints
328
345
if validator is not None and job_id is not None and axon is not None :
329
- checkpoint_numbers = np .random .choice (
330
- range (self .number_of_checkpoints ),
331
- size = c .MAX_CHECKPOINTS_TO_VALIDATE ,
332
- replace = False ,
333
- ).tolist ()
346
+ checkpoint_numbers = self .select_stratified_checkpoints (
347
+ num_checkpoints = self .number_of_checkpoints ,
348
+ num_samples = c .MAX_CHECKPOINTS_TO_VALIDATE + 1 , # +1 for Final
349
+ )
334
350
335
351
# Get intermediate checkpoints from the miner
336
352
intermediate_checkpoints = await self .get_intermediate_checkpoints (
@@ -385,11 +401,33 @@ async def is_run_valid(self, validator=None, job_id=None, axon=None):
385
401
if not is_valid :
386
402
return False , checked_energies_dict , miner_energies_dict , result
387
403
404
+ # Check that the miner-reported energies of the validated checkpoints are not (near-)duplicates of one another.
405
+ miner_reported_energies = []
406
+ checkpoint_length = len (
407
+ self .miner_reported_energies [str (checkpoint_numbers [0 ])]
408
+ )
409
+ for _ , energy in self .miner_reported_energies .items ():
410
+ miner_reported_energies .append (
411
+ energy [:checkpoint_length ]
412
+ ) # final cpt is larger in length.
413
+
414
+ if not check_uniqueness (
415
+ vectors = miner_reported_energies ,
416
+ tol = c .MINER_CHECKPOINT_SIMILARITY_TOLERANCE ,
417
+ ):
418
+ logger .warning ("Miner checkpoints not unique" )
419
+ return (
420
+ False ,
421
+ checked_energies_dict ,
422
+ miner_energies_dict ,
423
+ "miner-checkpoint-similarity" ,
424
+ )
425
+
388
426
return True , checked_energies_dict , miner_energies_dict , "valid"
389
427
390
- except ValidationError as E :
391
- logger .warning (f"{ E } " )
392
- return False , {}, {}, E .message
428
+ except ValidationError as e :
429
+ logger .warning (f"{ e } " )
430
+ return False , {}, {}, e .message
393
431
394
432
return True , checked_energies_dict , miner_energies_dict , "valid"
395
433
@@ -429,6 +467,7 @@ async def validate(self, validator=None, job_id=None, axon=None):
429
467
430
468
# Use the final checkpoint's energy for the score
431
469
if "final" in checked_energies_dict and checked_energies_dict ["final" ]:
470
+ logger .success (f"Hotkey { self .hotkey_alias } passed validation!" )
432
471
final_energies = checked_energies_dict ["final" ]
433
472
# Take the median of the last ENERGY_WINDOW_SIZE values
434
473
median_energy = np .median (final_energies [- c .ENERGY_WINDOW_SIZE :])
@@ -481,6 +520,17 @@ async def get_intermediate_checkpoints(
481
520
def name (self ) -> str :
482
521
return "SyntheticMD"
483
522
523
def get_miner_log_file_energies(
    self, start_index: int, end_index: int
) -> np.ndarray:
    """Return the potential energies logged by the miner for a step window.

    Rows of ``self.log_file`` are kept when their ``#"Step"`` value lies in
    the half-open interval ``(start_index, end_index]``.

    Args:
        start_index: Lower step bound (exclusive).
        end_index: Upper step bound (inclusive).

    Returns:
        The ``Potential Energy (kJ/mole)`` values of the matching rows, in
        log-file order.
    """
    steps = self.log_file['#"Step"']
    after_start = steps > start_index
    up_to_end = steps <= end_index

    # Boolean row mask selects the window; the column is then pulled out as a
    # raw array.
    return self.log_file.loc[
        after_start & up_to_end, "Potential Energy (kJ/mole)"
    ].values
533
+
484
534
def is_checkpoint_valid (
485
535
self ,
486
536
checkpoint_path : str ,
@@ -544,7 +594,6 @@ def is_checkpoint_valid(
544
594
545
595
try :
546
596
if not self .check_gradient (check_energies = np .array (state_energies )):
547
- logger .warning (f"state energies: { state_energies } " )
548
597
logger .warning (
549
598
f"hotkey { self .hotkey_alias } failed state-gradient check for { self .pdb_id } , checkpoint_num: { checkpoint_num } , ... Skipping!"
550
599
)
@@ -589,10 +638,11 @@ def is_checkpoint_valid(
589
638
590
639
max_step = current_cpt_step + steps_to_run
591
640
592
- miner_energies : np .ndarray = self .log_file [
593
- (self .log_file ['#"Step"' ] > current_cpt_step )
594
- & (self .log_file ['#"Step"' ] <= max_step )
595
- ]["Potential Energy (kJ/mole)" ].values
641
+ miner_energies : np .ndarray = self .get_miner_log_file_energies (
642
+ start_index = current_cpt_step , end_index = max_step
643
+ )
644
+
645
+ self .miner_reported_energies [checkpoint_num ] = miner_energies
596
646
597
647
if len (np .unique (check_energies )) == 1 :
598
648
logger .warning (
@@ -601,7 +651,6 @@ def is_checkpoint_valid(
601
651
raise ValidationError (message = "reprod-energies-identical" )
602
652
603
653
if not self .check_gradient (check_energies = np .array (check_energies )):
604
- logger .warning (f"check_energies: { check_energies } " )
605
654
logger .warning (
606
655
f"hotkey { self .hotkey_alias } failed cpt-gradient check for { self .pdb_id } , checkpoint_num: { checkpoint_num } , ... Skipping!"
607
656
)
@@ -641,9 +690,9 @@ def is_checkpoint_valid(
641
690
642
691
return True , check_energies .tolist (), miner_energies .tolist (), "valid"
643
692
644
- except ValidationError as E :
645
- logger .warning (f"{ E } " )
646
- return False , [], [], E .message
693
+ except ValidationError as e :
694
+ logger .warning (f"{ e } " )
695
+ return False , [], [], e .message
647
696
648
697
649
698
class OrganicMDEvaluator (SyntheticMDEvaluator ):
0 commit comments