1
- import json
2
1
import os
2
+ import json
3
+ import string
3
4
import random
4
5
import sqlite3
5
- import string
6
- from dataclasses import asdict , dataclass
7
- from datetime import datetime
6
+ import requests
8
7
from queue import Queue
9
- from typing import List
8
+ from typing import Dict , List
9
+
10
+ from datetime import datetime , timezone
11
+ from dataclasses import asdict , dataclass
10
12
11
13
import numpy as np
12
14
import pandas as pd
13
- import requests
15
+
14
16
from atom .epistula .epistula import Epistula
17
+ from folding .utils .epistula_utils import get_epistula_body
15
18
16
19
DB_DIR = os .path .join (os .path .dirname (__file__ ), "db" )
17
20
@@ -53,7 +56,10 @@ def init_db(self):
53
56
water TEXT,
54
57
epsilon REAL,
55
58
system_kwargs TEXT,
56
- min_updates INTEGER
59
+ min_updates INTEGER,
60
+ job_id TEXT,
61
+ s3_links TEXT,
62
+ best_cpt_links TEXT
57
63
)
58
64
"""
59
65
)
@@ -67,9 +73,7 @@ def _row_to_job(self, row) -> "Job":
67
73
# Convert stored JSON strings back to Python objects
68
74
data ["hotkeys" ] = json .loads (data ["hotkeys" ])
69
75
data ["event" ] = json .loads (data ["event" ]) if data ["event" ] else None
70
- data ["system_kwargs" ] = (
71
- json .loads (data ["system_kwargs" ]) if data ["system_kwargs" ] else None
72
- )
76
+ data ["system_kwargs" ] = json .loads (data ["system_kwargs" ]) if data ["system_kwargs" ] else None
73
77
74
78
# Convert timestamps
75
79
for field in ["created_at" , "updated_at" , "best_loss_at" ]:
@@ -80,9 +84,7 @@ def _row_to_job(self, row) -> "Job":
80
84
81
85
# Convert intervals
82
86
data ["update_interval" ] = pd .Timedelta (seconds = data ["update_interval" ])
83
- data ["max_time_no_improvement" ] = pd .Timedelta (
84
- seconds = data ["max_time_no_improvement" ]
85
- )
87
+ data ["max_time_no_improvement" ] = pd .Timedelta (seconds = data ["max_time_no_improvement" ])
86
88
87
89
# Convert boolean
88
90
data ["active" ] = bool (data ["active" ])
@@ -93,12 +95,13 @@ def _job_to_dict(self, job: "Job") -> dict:
93
95
"""Convert a Job object to a dictionary for database storage."""
94
96
data = job .to_dict ()
95
97
96
- # Convert Python objects to JSON strings
97
- data ["hotkeys" ] = json .dumps (data ["hotkeys" ])
98
- data ["event" ] = json .dumps (data ["event" ]) if data ["event" ] else None
99
- data ["system_kwargs" ] = (
100
- json .dumps (data ["system_kwargs" ]) if data ["system_kwargs" ] else None
101
- )
98
+ # Convert Python list or dict objects to JSON strings for sqlite
99
+ data_to_update = {}
100
+ for k , v in data .items ():
101
+ if isinstance (v , (list ,dict )):
102
+ data_to_update [k ] = json .dumps (v )
103
+
104
+ data .update (data_to_update )
102
105
103
106
# Convert timestamps to strings
104
107
for field in ["created_at" , "updated_at" , "best_loss_at" ]:
@@ -109,9 +112,7 @@ def _job_to_dict(self, job: "Job") -> dict:
109
112
110
113
# Convert intervals to seconds
111
114
data ["update_interval" ] = int (data ["update_interval" ].total_seconds ())
112
- data ["max_time_no_improvement" ] = int (
113
- data ["max_time_no_improvement" ].total_seconds ()
114
- )
115
+ data ["max_time_no_improvement" ] = int (data ["max_time_no_improvement" ].total_seconds ())
115
116
116
117
# Convert boolean to integer
117
118
data ["active" ] = int (data ["active" ])
@@ -154,6 +155,7 @@ def insert(
154
155
hotkeys : List [str ],
155
156
epsilon : float ,
156
157
system_kwargs : dict ,
158
+ job_id : str ,
157
159
** kwargs ,
158
160
):
159
161
"""Insert a new job into the database."""
@@ -176,6 +178,7 @@ def insert(
176
178
updated_at = pd .Timestamp .now ().floor ("s" ),
177
179
epsilon = epsilon ,
178
180
system_kwargs = system_kwargs ,
181
+ job_id = job_id ,
179
182
** kwargs ,
180
183
)
181
184
@@ -202,6 +205,34 @@ def update(self, job: "Job"):
202
205
203
206
cur .execute (query , list (data .values ()) + [pdb ])
204
207
208
def update_gjp_job(self, job: "Job", gjp_address: str, keypair, job_id: str):
    """
    Send the current state of a job to the GJP server's update endpoint.

    The job is turned into a request body with ``get_epistula_body``, encoded
    with ``self.epistula.create_message_body`` and signed with
    ``self.epistula.generate_header`` before being POSTed to
    ``http://{gjp_address}/jobs/update/{job_id}``.

    Args:
        job (Job): The job object whose details form the request body.
        gjp_address (str): Address of the GJP server; it is interpolated
            directly after ``http://`` when building the URL.
        keypair: Passed as ``hotkey`` to ``self.epistula.generate_header``
            to produce the request headers.
        job_id (str): ID of the job to update; used as the last URL segment.

    Returns:
        The ``"job_id"`` field of the server's JSON response.

    Raises:
        ValueError: If the response status code is not 200.
    """
    body = get_epistula_body(job=job)

    # Headers are generated from the encoded bytes, so the exact same bytes
    # must be sent as the request payload.
    body_bytes = self.epistula.create_message_body(body)
    headers = self.epistula.generate_header(hotkey=keypair, body=body_bytes)

    # NOTE(review): no timeout is passed, so this call can block for as long
    # as the server keeps the connection open - confirm this is intended.
    response = requests.post(
        f"http://{gjp_address}/jobs/update/{job_id}",
        headers=headers,
        data=body_bytes,
    )
    if response.status_code != 200:
        # This is the update path; the message previously said "upload",
        # which made failures indistinguishable from upload_job errors.
        raise ValueError(f"Failed to update job: {response.text}")
    return response.json()["job_id"]
235
+
205
236
def get_all_pdbs (self ) -> list :
206
237
"""
207
238
Retrieve all PDB IDs from the job store.
@@ -229,32 +260,55 @@ def upload_job(
229
260
water : str ,
230
261
hotkeys : list ,
231
262
system_kwargs : dict ,
232
- hotkey ,
263
+ keypair ,
233
264
gjp_address : str ,
265
+ epsilon : float ,
266
+ s3_links : Dict [str , str ],
267
+ ** kwargs ,
234
268
):
235
- """Upload a job to the database."""
236
-
237
- body = {
238
- "pdb_id" : pdb ,
239
- "hotkeys" : hotkeys ,
240
- "system_config" : {
241
- "ff" : ff ,
242
- "box" : box ,
243
- "water" : water ,
244
- "system_kwargs" : system_kwargs ,
245
- },
246
- "em_s3_link" : "s3://path/to/em" ,
247
- "priority" : 1 ,
248
- "organic" : False
249
- }
250
- body_bytes = self .epistula .create_message_body (body )
251
- headers = self .epistula .generate_header (hotkey = hotkey , body = body_bytes )
269
+ """
270
+ Upload a job to the global job pool database.
271
+
272
+ Args:
273
+ pdb (str): The PDB ID of the job.
274
+ ff (str): The force field configuration.
275
+ box (str): The box configuration.
276
+ water (str): The water configuration.
277
+ hotkeys (list): A list of hotkeys.
278
+ system_kwargs (dict): Additional system configuration arguments.
279
+ keypair (Keypair): The keypair for generating headers.
280
+ gjp_address (str): The address of the api server.
281
+ event (dict): Additional event data.
252
282
253
- response = requests .post (
254
- f"http://{ gjp_address } /jobs" , headers = headers , data = body_bytes
283
+ Returns:
284
+ str: The job ID of the uploaded job.
285
+
286
+ Raises:
287
+ ValueError: If the job upload fails.
288
+ """
289
+ job = Job (
290
+ pdb = pdb ,
291
+ ff = ff ,
292
+ box = box ,
293
+ water = water ,
294
+ hotkeys = hotkeys ,
295
+ created_at = pd .Timestamp .now ().floor ("s" ),
296
+ updated_at = pd .Timestamp .now ().floor ("s" ),
297
+ epsilon = epsilon ,
298
+ system_kwargs = system_kwargs ,
299
+ s3_links = s3_links ,
300
+ ** kwargs ,
255
301
)
302
+
303
+ body = get_epistula_body (job = job )
304
+
305
+ body_bytes = self .epistula .create_message_body (body )
306
+ headers = self .epistula .generate_header (hotkey = keypair , body = body_bytes )
307
+
308
+ response = requests .post (f"http://{ gjp_address } /jobs" , headers = headers , data = body_bytes )
256
309
if response .status_code != 200 :
257
310
raise ValueError (f"Failed to upload job: { response .text } " )
311
+ return response .json ()["job_id" ]
258
312
259
313
260
314
# Keep the Job and MockJob classes as they are; they work well with both implementations.
@@ -280,6 +334,9 @@ class Job:
280
334
epsilon : float = 5 # percentage.
281
335
event : dict = None
282
336
system_kwargs : dict = None
337
+ job_id : str = None
338
+ s3_links : Dict [str , str ] = None
339
+ best_cpt_links : list = None
283
340
284
341
def to_dict (self ):
285
342
return asdict (self )
@@ -301,25 +358,21 @@ async def update(self, loss: float, hotkey: str, hotkeys: List[str] = None):
301
358
self .updated_at = pd .Timestamp .now ().floor ("s" )
302
359
self .updated_count += 1
303
360
304
- never_updated_better_loss = (
305
- np .isnan (percent_improvement ) and loss < self .best_loss
306
- )
361
+ never_updated_better_loss = np .isnan (percent_improvement ) and loss < self .best_loss
307
362
better_loss = percent_improvement >= self .epsilon
308
363
309
364
if never_updated_better_loss or better_loss :
310
365
self .best_loss = loss
311
366
self .best_loss_at = pd .Timestamp .now ().floor ("s" )
312
367
self .best_hotkey = hotkey
313
368
elif (
314
- pd .Timestamp .now ().floor ("s" ) - self .best_loss_at
315
- > self .max_time_no_improvement
369
+ pd .Timestamp .now ().floor ("s" ) - self .best_loss_at > self .max_time_no_improvement
316
370
and self .updated_count >= self .min_updates
317
371
):
318
372
self .active = False
319
373
elif (
320
374
isinstance (self .best_loss_at , pd ._libs .tslibs .nattype .NaTType )
321
- and pd .Timestamp .now ().floor ("s" ) - self .created_at
322
- > self .max_time_no_improvement
375
+ and pd .Timestamp .now ().floor ("s" ) - self .created_at > self .max_time_no_improvement
323
376
):
324
377
self .active = False
325
378
@@ -339,10 +392,9 @@ def __init__(self, n_hotkeys=5, update_seconds=5, stop_after_seconds=3600 * 24):
339
392
self .box = "cube"
340
393
self .water = "tip3p"
341
394
self .hotkeys = self ._make_hotkeys (n_hotkeys )
342
- self .created_at = (
343
- pd .Timestamp .now ().floor ("s" )
344
- - pd .Timedelta (seconds = random .randint (0 , 3600 * 24 ))
345
- ).floor ("s" )
395
+ self .created_at = (pd .Timestamp .now ().floor ("s" ) - pd .Timedelta (seconds = random .randint (0 , 3600 * 24 ))).floor (
396
+ "s"
397
+ )
346
398
self .updated_at = self .created_at
347
399
self .best_loss = 0
348
400
self .best_hotkey = random .choice (self .hotkeys )
0 commit comments