Skip to content

Commit ea02034

Browse files
committed
Merge branch 'develop' into file-transfer-enhancements
2 parents 821e942 + 9483720 commit ea02034

File tree

17 files changed

+420
-209
lines changed

17 files changed

+420
-209
lines changed

.github/workflows/license.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ jobs:
6262
with:
6363
requirements: "requirements.txt,tests/requirements.txt,requirements-client.txt"
6464
fail: "Copyleft,Error,Other"
65-
exclude: "^(pylint|aio[-_]*).*"
65+
exclude: "^(pylint|aio[-_]*|pytest-asyncio|typing-extensions).*"
6666
exclude-license: 'Mozilla Public License 2.0 \(MPL 2.0\)'
6767
totals: true
6868
headers: true

CHANGELOG.md

+9-4
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,24 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1111

1212
- Allow registering custom file transfer strategies
1313

14+
### Removed
15+
16+
- Removed obsolete endpoint to register sublattice manifest
17+
18+
### Operations
19+
20+
- Whitelisted two packages (`pytest-asyncio`, `typing-extensions`) for the license check
21+
1422
### Changed
1523

24+
- Moved in-memory dispatcher state to a database (a temporary SQLite file by default, or the main DB when `COVALENT_CACHE_DB_URL` points to it)
1625
- Improved automatic file transfer strategy selection
1726
- HTTP strategy can now upload files too
1827
- Adjusted sublattice logic. The sublattice builder now attempts to
1928
link the sublattice with its parent electron.
2029
- Replaced JSON sublattice flow with new tarball importer to allow future memory
2130
footprint enhancements
2231

23-
### Removed
24-
25-
- Removed obsolete endpoint to register sublattice manifest
26-
2732
## [0.238.0-rc.0] - 2025-03-05
2833

2934
### Authors

covalent_dispatcher/_core/dispatcher.py

+14-16
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from . import runner_ng
3535
from .data_modules import graph as tg_utils
3636
from .data_modules import job_manager as jbmgr
37-
from .dispatcher_modules.caches import _pending_parents, _sorted_task_groups, _unresolved_tasks
37+
from .dispatcher_modules.caches import _task_group_cache, _workflow_run_cache
3838
from .runner_modules.cancel import cancel_tasks
3939

4040
app_log = logger.app_log
@@ -95,7 +95,7 @@ async def _handle_completed_node(dispatch_id: str, node_id: int):
9595
gid = child["task_group_id"]
9696
app_log.debug(f"dispatch {dispatch_id}: parent gid {parent_gid}, child gid {gid}")
9797
if parent_gid != gid:
98-
now_pending = await _pending_parents.decrement(dispatch_id, gid)
98+
now_pending = await _task_group_cache.decrement(dispatch_id, gid)
9999
if now_pending < 1:
100100
app_log.debug(f"Queuing task group {gid} for execution")
101101
next_task_groups.append(gid)
@@ -414,13 +414,12 @@ async def _finalize_dispatch(dispatch_id: str):
414414

415415

416416
async def _initialize_caches(dispatch_id, pending_parents, sorted_task_groups):
417-
for gid, indegree in pending_parents.items():
418-
await _pending_parents.set_pending(dispatch_id, gid, indegree)
419-
420-
for gid, sorted_nodes in sorted_task_groups.items():
421-
await _sorted_task_groups.set_task_group(dispatch_id, gid, sorted_nodes)
417+
for gid in pending_parents:
418+
indegree = pending_parents[gid]
419+
sorted_nodes = sorted_task_groups[gid]
420+
await _task_group_cache.set(dispatch_id, gid, indegree, sorted_nodes)
422421

423-
await _unresolved_tasks.set_unresolved(dispatch_id, 0)
422+
await _workflow_run_cache.set_unresolved(dispatch_id, 0)
424423

425424

426425
async def _submit_initial_tasks(dispatch_id: str):
@@ -442,7 +441,7 @@ async def _submit_initial_tasks(dispatch_id: str):
442441
for gid in initial_groups:
443442
sorted_nodes = sorted_task_groups[gid]
444443
app_log.debug(f"Sorted nodes group group {gid}: {sorted_nodes}")
445-
await _unresolved_tasks.increment(dispatch_id, len(sorted_nodes))
444+
await _workflow_run_cache.increment(dispatch_id, len(sorted_nodes))
446445

447446
for gid in initial_groups:
448447
sorted_nodes = sorted_task_groups[gid]
@@ -469,8 +468,8 @@ async def _handle_node_status_update(dispatch_id, node_id, node_status, detail):
469468
if node_status == RESULT_STATUS.COMPLETED:
470469
next_task_groups = await _handle_completed_node(dispatch_id, node_id)
471470
for gid in next_task_groups:
472-
sorted_nodes = await _sorted_task_groups.get_task_group(dispatch_id, gid)
473-
await _unresolved_tasks.increment(dispatch_id, len(sorted_nodes))
471+
sorted_nodes = await _task_group_cache.get_task_group(dispatch_id, gid)
472+
await _workflow_run_cache.increment(dispatch_id, len(sorted_nodes))
474473
await _submit_task_group(dispatch_id, sorted_nodes, gid)
475474

476475
if node_status == RESULT_STATUS.FAILED:
@@ -481,7 +480,7 @@ async def _handle_node_status_update(dispatch_id, node_id, node_status, detail):
481480

482481
# Decrement after any increments to avoid race with
483482
# finalize_dispatch()
484-
await _unresolved_tasks.decrement(dispatch_id)
483+
await _workflow_run_cache.decrement(dispatch_id)
485484

486485

487486
async def _handle_dispatch_exception(dispatch_id: str, ex: Exception) -> RESULT_STATUS:
@@ -532,7 +531,7 @@ async def _handle_event(msg: Dict):
532531
fut.set_result(dispatch_status)
533532
return dispatch_status
534533

535-
unresolved = await _unresolved_tasks.get_unresolved(dispatch_id)
534+
unresolved = await _workflow_run_cache.get_unresolved(dispatch_id)
536535
if unresolved < 1:
537536
app_log.debug("Finalizing dispatch")
538537
try:
@@ -551,7 +550,7 @@ async def _handle_event(msg: Dict):
551550

552551
async def _clear_caches(dispatch_id: str):
553552
"""Clean up all keys in caches."""
554-
await _unresolved_tasks.remove(dispatch_id)
553+
await _workflow_run_cache.remove(dispatch_id)
555554

556555
g_node_link = await tg_utils.get_nodes_links(dispatch_id)
557556
g = nx.readwrite.node_link_graph(g_node_link)
@@ -560,5 +559,4 @@ async def _clear_caches(dispatch_id: str):
560559

561560
for gid in task_groups:
562561
# Clean up keys that are no longer referenced
563-
await _pending_parents.remove(dispatch_id, gid)
564-
await _sorted_task_groups.remove(dispatch_id, gid)
562+
await _task_group_cache.remove(dispatch_id, gid)

covalent_dispatcher/_core/dispatcher_modules/caches.py

+166-51
Original file line numberDiff line numberDiff line change
@@ -18,84 +18,199 @@
1818
Helper classes for the dispatcher
1919
"""
2020

21-
from .store import _DictStore, _KeyValueBase
21+
import json
22+
import os
23+
import tempfile
2224

25+
from covalent_dispatcher._core.data_modules.utils import run_in_executor
26+
from covalent_dispatcher._dal.base import workflow_db
27+
from covalent_dispatcher._dal.dispatcher_state import TaskGroupState, WorkflowState
28+
from covalent_dispatcher._db.datastore import DataStore
2329

24-
def _pending_parents_key(dispatch_id: str, node_id: int):
25-
return f"pending-parents-{dispatch_id}:{node_id}"
2630

31+
class _WorkflowRunState:
2732

28-
def _unresolved_tasks_key(dispatch_id: str):
29-
return f"unresolved-{dispatch_id}"
33+
def __init__(self, db: DataStore):
34+
self.db = db
3035

31-
32-
def _task_groups_key(dispatch_id: str, task_group_id: int):
33-
return f"task-groups-{dispatch_id}:{task_group_id}"
34-
35-
36-
class _UnresolvedTasksCache:
37-
def __init__(self, store: _KeyValueBase = _DictStore()):
38-
self._store = store
36+
def _get_unresolved(self, dispatch_id: str):
37+
with self.db.session() as session:
38+
records = WorkflowState.get(
39+
session,
40+
fields=["num_unresolved_tasks"],
41+
equality_filters={"dispatch_id": dispatch_id},
42+
membership_filters={},
43+
)
44+
return records[0].num_unresolved_tasks
3945

4046
async def get_unresolved(self, dispatch_id: str):
41-
key = _unresolved_tasks_key(dispatch_id)
42-
return await self._store.get(key)
47+
return await run_in_executor(self._get_unresolved, dispatch_id)
48+
49+
def _set_unresolved(self, dispatch_id: str, val: int):
50+
with self.db.session() as session:
51+
WorkflowState.create(
52+
session,
53+
insert_kwargs={
54+
"dispatch_id": dispatch_id,
55+
"num_unresolved_tasks": val,
56+
},
57+
)
58+
session.commit()
4359

4460
async def set_unresolved(self, dispatch_id: str, val: int):
45-
key = _unresolved_tasks_key(dispatch_id)
46-
await self._store.insert(key, val)
61+
return await run_in_executor(self._set_unresolved, dispatch_id, val)
62+
63+
def _increment(self, dispatch_id: str, interval: int = 1):
64+
with self.db.session() as session:
65+
WorkflowState.incr_bulk(
66+
session=session,
67+
increments={"num_unresolved_tasks": interval},
68+
equality_filters={"dispatch_id": dispatch_id},
69+
membership_filters={},
70+
)
71+
records = WorkflowState.get(
72+
session,
73+
fields=["num_unresolved_tasks"],
74+
equality_filters={"dispatch_id": dispatch_id},
75+
membership_filters={},
76+
)
77+
session.commit()
78+
return records[0].num_unresolved_tasks
4779

4880
async def increment(self, dispatch_id: str, interval: int = 1):
49-
key = _unresolved_tasks_key(dispatch_id)
50-
return await self._store.increment(key, interval)
81+
return await run_in_executor(self._increment, dispatch_id, interval)
82+
83+
def _decrement(self, dispatch_id: str):
84+
with self.db.session() as session:
85+
WorkflowState.incr_bulk(
86+
session=session,
87+
increments={"num_unresolved_tasks": -1},
88+
equality_filters={"dispatch_id": dispatch_id},
89+
membership_filters={},
90+
)
91+
records = WorkflowState.get(
92+
session,
93+
fields=["num_unresolved_tasks"],
94+
equality_filters={"dispatch_id": dispatch_id},
95+
membership_filters={},
96+
)
97+
session.commit()
98+
return records[0].num_unresolved_tasks
5199

52100
async def decrement(self, dispatch_id: str):
53-
key = _unresolved_tasks_key(dispatch_id)
54-
return await self._store.increment(key, -1)
101+
return await run_in_executor(self._decrement, dispatch_id)
102+
103+
def _remove(self, dispatch_id: str):
104+
with self.db.session() as session:
105+
WorkflowState.delete_bulk(
106+
session=session,
107+
equality_filters={"dispatch_id": dispatch_id},
108+
membership_filters={},
109+
)
110+
session.commit()
55111

56112
async def remove(self, dispatch_id: str):
57-
key = _unresolved_tasks_key(dispatch_id)
58-
await self._store.remove(key)
113+
await run_in_executor(self._remove, dispatch_id)
59114

60115

61-
class _PendingParentsCache:
62-
def __init__(self, store: _KeyValueBase = _DictStore()):
63-
self._store = store
116+
class TaskGroupRunState:
64117

65-
async def get_pending(self, dispatch_id: str, task_group_id: int):
66-
key = _pending_parents_key(dispatch_id, task_group_id)
67-
return await self._store.get(key)
118+
def __init__(self, db):
119+
self.db = db
120+
121+
def _get_pending(self, dispatch_id: str, task_group_id: int):
122+
with self.db.session() as session:
123+
records = TaskGroupState.get(
124+
session=session,
125+
fields=["num_pending_parents"],
126+
equality_filters={"dispatch_id": dispatch_id, "task_group_id": task_group_id},
127+
membership_filters={},
128+
)
129+
return records[0].num_pending_parents
68130

69-
async def set_pending(self, dispatch_id: str, task_group_id: int, val: int):
70-
key = _pending_parents_key(dispatch_id, task_group_id)
71-
await self._store.insert(key, val)
131+
async def get_pending(self, dispatch_id: str, task_group_id: int):
132+
return await run_in_executor(self._get_pending, dispatch_id, task_group_id)
133+
134+
def _set(self, dispatch_id: str, task_group_id: int, num_pending: int, sorted_nodes):
135+
with self.db.session() as session:
136+
TaskGroupState.create(
137+
session=session,
138+
insert_kwargs={
139+
"dispatch_id": dispatch_id,
140+
"task_group_id": task_group_id,
141+
"num_pending_parents": num_pending,
142+
"sorted_tasks": json.dumps(sorted_nodes),
143+
},
144+
)
145+
session.commit()
146+
147+
async def set(self, dispatch_id: str, task_group_id: int, num_pending: int, sorted_nodes):
148+
return await run_in_executor(
149+
self._set, dispatch_id, task_group_id, num_pending, sorted_nodes
150+
)
151+
152+
def _decrement(self, dispatch_id: str, task_group_id):
153+
with self.db.session() as session:
154+
TaskGroupState.incr_bulk(
155+
session=session,
156+
increments={"num_pending_parents": -1},
157+
equality_filters={"dispatch_id": dispatch_id, "task_group_id": task_group_id},
158+
membership_filters={},
159+
)
160+
records = TaskGroupState.get(
161+
session,
162+
fields=["num_pending_parents"],
163+
equality_filters={"dispatch_id": dispatch_id, "task_group_id": task_group_id},
164+
membership_filters={},
165+
)
166+
session.commit()
167+
return records[0].num_pending_parents
72168

73169
async def decrement(self, dispatch_id: str, task_group_id: int):
74-
key = _pending_parents_key(dispatch_id, task_group_id)
75-
return await self._store.increment(key, -1)
170+
return await run_in_executor(self._decrement, dispatch_id, task_group_id)
76171

77172
async def remove(self, dispatch_id: str, task_group_id: int):
78-
key = _pending_parents_key(dispatch_id, task_group_id)
79-
await self._store.remove(key)
80-
81-
82-
class _SortedTaskGroups:
83-
def __init__(self, store: _KeyValueBase = _DictStore()):
84-
self._store = store
173+
pass
174+
175+
def _get_task_group(self, dispatch_id: str, task_group_id: int):
176+
with self.db.session() as session:
177+
records = TaskGroupState.get(
178+
session=session,
179+
fields=["sorted_tasks"],
180+
equality_filters={"dispatch_id": dispatch_id, "task_group_id": task_group_id},
181+
membership_filters={},
182+
)
183+
return json.loads(records[0].sorted_tasks)
85184

86185
async def get_task_group(self, dispatch_id: str, task_group_id: int):
87-
key = _task_groups_key(dispatch_id, task_group_id)
88-
return await self._store.get(key)
186+
return await run_in_executor(self._get_task_group, dispatch_id, task_group_id)
89187

90-
async def set_task_group(self, dispatch_id: str, task_group_id: int, sorted_nodes: list):
91-
key = _task_groups_key(dispatch_id, task_group_id)
92-
await self._store.insert(key, sorted_nodes)
188+
def _remove(self, dispatch_id: str, task_group_id: int):
189+
with self.db.session() as session:
190+
TaskGroupState.delete_bulk(
191+
session=session,
192+
equality_filters={"dispatch_id": dispatch_id, "task_group_id": task_group_id},
193+
membership_filters={},
194+
)
195+
session.commit()
93196

94197
async def remove(self, dispatch_id: str, task_group_id: int):
95-
key = _task_groups_key(dispatch_id, task_group_id)
96-
await self._store.remove(key)
198+
await run_in_executor(self._remove, dispatch_id, task_group_id)
199+
200+
201+
# Default to a temporary (typically tmpfs-backed) SQLite file; set COVALENT_CACHE_DB_URL to override
202+
cache_db_file = tempfile.NamedTemporaryFile(
203+
mode="w+b", prefix="covalent-dispatcher-cache-", suffix=".db"
204+
)
205+
cache_db_URL = os.environ.get("COVALENT_CACHE_DB_URL", f"sqlite+pysqlite:///{cache_db_file.name}")
206+
initialize_db = True
207+
208+
# If we want to store dispatcher state in the main DB, let the alembic migrations
209+
# create the tables
210+
if cache_db_URL == workflow_db.db_URL:
211+
initialize_db = False
97212

213+
cache_db = DataStore(db_URL=cache_db_URL, initialize_db=initialize_db)
98214

99-
_pending_parents = _PendingParentsCache()
100-
_unresolved_tasks = _UnresolvedTasksCache()
101-
_sorted_task_groups = _SortedTaskGroups()
215+
_task_group_cache = TaskGroupRunState(db=cache_db)
216+
_workflow_run_cache = _WorkflowRunState(db=cache_db)

0 commit comments

Comments
 (0)