Skip to content

Commit 927caad

Browse files
authored
fix: use manual db session in background tasks (#33)
* feat: db manual session
* fix: preview dataset query
* fix: cache dataset
* fix: cache dataset
* feat: separate shard statistics table
* chore: build ui
* fix: delete task status if done
* fix: unified background task status cache
* chore: build ui
* fix: python openapi client
* fix: shard sync delete
* fix: test
1 parent c0c255a commit 927caad

File tree

109 files changed

+3201
-649
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

109 files changed

+3201
-649
lines changed

lavender_data/server/background_worker/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
from .background_worker import (
66
TaskStatus,
7-
TaskMetadata,
87
BackgroundWorker,
98
get_background_worker,
109
setup_background_worker,
@@ -19,7 +18,6 @@
1918

2019
__all__ = [
2120
"TaskStatus",
22-
"TaskMetadata",
2321
"BackgroundWorker",
2422
"get_background_worker",
2523
"setup_background_worker",

lavender_data/server/background_worker/background_worker.py

Lines changed: 40 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,6 @@ class TaskStatus(BaseModel):
2121
total: int
2222

2323

24-
class TaskMetadata(BaseModel):
25-
uid: str
26-
name: str
27-
start_time: datetime
28-
kwargs: dict
29-
status: Optional[TaskStatus] = None
30-
31-
3224
def set_task_status(
3325
task_id: str,
3426
status: Optional[str] = None,
@@ -40,21 +32,21 @@ def set_task_status(
4032
if _status is None:
4133
_status = TaskStatus(status="", current=0, total=0)
4234

43-
next(get_cache()).set(
44-
f"task-{task_id}",
35+
next(get_cache()).hset(
36+
f"background-worker:tasks",
37+
task_id,
4538
json.dumps(
4639
{
47-
"status": (status if status is not None else _status.status),
48-
"current": (current if current is not None else _status.current),
40+
"status": status if status is not None else _status.status,
41+
"current": current if current is not None else _status.current,
4942
"total": total if total is not None else _status.total,
5043
}
5144
),
52-
ex=ex,
5345
)
5446

5547

5648
def get_task_status(task_uid: str) -> Optional[TaskStatus]:
57-
status = next(get_cache()).get(f"task-{task_uid}")
49+
status = next(get_cache()).hget(f"background-worker:tasks", task_uid)
5850
if status is None:
5951
return None
6052

@@ -67,7 +59,15 @@ def get_task_status(task_uid: str) -> Optional[TaskStatus]:
6759

6860

6961
def delete_task_status(task_uid: str):
70-
next(get_cache()).delete(f"task-{task_uid}")
62+
next(get_cache()).hdel(f"background-worker:tasks", task_uid)
63+
64+
65+
def all_task_statuses() -> dict[str, TaskStatus]:
66+
tasks = next(get_cache()).hgetall(f"background-worker:tasks")
67+
return {
68+
task_id: TaskStatus.model_validate(json.loads(t))
69+
for task_id, t in tasks.items()
70+
}
7171

7272

7373
class Aborted(Exception):
@@ -129,7 +129,6 @@ def _run_task_no_status(
129129

130130

131131
class TaskItem(NamedTuple):
132-
metadata: TaskMetadata
133132
future: Future
134133
abort_event: threading.Event
135134

@@ -141,40 +140,35 @@ def __init__(self, num_workers: int):
141140

142141
self._process_pool = ProcessPool(self._num_workers)
143142

144-
self._tasks: list[TaskItem] = []
145-
self._tasks_lock = threading.Lock()
146-
147-
self._task_status: dict[str, TaskStatus] = {}
148-
149143
self._executor = ThreadPoolExecutor(self._num_workers)
144+
self._abort_events: dict[str, threading.Event] = {}
145+
self._futures: dict[str, Future] = {}
150146

151147
self._start_cleanup_thread()
152148

153149
def process_pool(self) -> ProcessPool:
154150
return self._process_pool
155151

156152
def _cleanup_tasks(self):
157-
with self._tasks_lock:
158-
for t in self._tasks:
159-
if get_task_status(t.metadata.uid) is None:
160-
self._tasks.remove(t)
153+
for task_id, status in all_task_statuses().items():
154+
if status.status == "completed":
155+
delete_task_status(task_id)
156+
157+
if status.status in ["completed", "aborted", "failed"]:
158+
self._abort_events.pop(task_id, None)
159+
self._futures.pop(task_id, None)
161160

162161
def _start_cleanup_thread(self):
163162
def _cleanup_tasks():
164163
while True:
165-
time.sleep(1)
164+
time.sleep(10)
166165
self._cleanup_tasks()
167166

168167
threading.Thread(target=_cleanup_tasks, daemon=True).start()
169168

170-
def running_tasks(self) -> list[TaskMetadata]:
169+
def list_tasks(self) -> dict[str, TaskStatus]:
171170
self._cleanup_tasks()
172-
with self._tasks_lock:
173-
tasks = [t.metadata for t in self._tasks]
174-
tasks.sort(key=lambda t: t.start_time)
175-
for task in tasks:
176-
task.status = get_task_status(task.uid)
177-
return tasks
171+
return {task_id: status for task_id, status in all_task_statuses().items()}
178172

179173
def get_task_status(self, task_id: str) -> Optional[TaskStatus]:
180174
return get_task_status(task_id)
@@ -204,19 +198,8 @@ def thread_pool_submit(
204198
**kwargs,
205199
)
206200

207-
with self._tasks_lock:
208-
self._tasks.append(
209-
TaskItem(
210-
metadata=TaskMetadata(
211-
uid=task_id,
212-
name=task_name or func.__name__,
213-
start_time=datetime.now(UTC),
214-
kwargs=kwargs,
215-
),
216-
future=future,
217-
abort_event=abort_event,
218-
)
219-
)
201+
self._abort_events[task_id] = abort_event
202+
self._futures[task_id] = future
220203

221204
return task_id
222205

@@ -228,20 +211,20 @@ def process_pool_submit(
228211
return self.process_pool().submit(func, **kwargs)
229212

230213
def abort(self, task_id: str):
231-
with self._tasks_lock:
232-
status = get_task_status(task_id)
233-
if status is not None:
234-
delete_task_status(task_id)
214+
status = get_task_status(task_id)
215+
if status is not None:
216+
delete_task_status(task_id)
217+
218+
if task_id in self._abort_events:
219+
self._abort_events[task_id].set()
235220

236-
task = next((t for t in self._tasks if t.metadata.uid == task_id), None)
237-
if task is not None:
238-
task.abort_event.set()
239-
task.future.cancel()
240-
self._tasks.remove(task)
221+
if task_id in self._futures:
222+
self._futures[task_id].cancel()
241223

242224
def abort_all(self):
243-
for t in self._tasks:
244-
self.abort(t.metadata.uid)
225+
for task_id, status in all_task_statuses().items():
226+
if status.status == "running":
227+
self.abort(task_id)
245228

246229
def shutdown(self):
247230
self._logger.debug("Shutting down background worker")

lavender_data/server/cli/create_api_key.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from sqlmodel import select
55

6-
from lavender_data.server.db import get_session, setup_db
6+
from lavender_data.server.db import db_manual_session, setup_db
77
from lavender_data.server.db.models import ApiKey
88
from lavender_data.server.settings import get_settings
99

@@ -13,18 +13,21 @@ def create_api_key(
1313
expires_at: Optional[datetime] = None,
1414
):
1515
setup_db(get_settings().lavender_data_db_url)
16-
session = next(get_session())
1716

18-
api_key = None
19-
if note:
20-
api_key = session.exec(select(ApiKey).where(ApiKey.note == note)).one_or_none()
17+
with db_manual_session() as session:
18+
api_key = None
19+
if note:
20+
api_key = session.exec(
21+
select(ApiKey).where(ApiKey.note == note)
22+
).one_or_none()
2123

22-
if api_key is None:
23-
api_key = ApiKey(note=note, expires_at=expires_at)
24-
session.add(api_key)
25-
else:
26-
api_key.expires_at = expires_at
24+
if api_key is None:
25+
api_key = ApiKey(note=note, expires_at=expires_at)
26+
session.add(api_key)
27+
else:
28+
api_key.expires_at = expires_at
29+
30+
session.commit()
31+
session.refresh(api_key)
2732

28-
session.commit()
29-
session.refresh(api_key)
3033
return api_key

lavender_data/server/dataset/preview.py

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,25 @@
22
import time
33
from typing import Any, Union
44

5-
from sqlalchemy.exc import NoResultFound
5+
from sqlmodel import select
6+
from sqlalchemy.orm import selectinload
7+
68
import filetype
79
import hashlib
810
import numpy as np
911
import json
1012

1113
from lavender_data.server.settings import files_dir
12-
from lavender_data.server.db import get_session
13-
from lavender_data.server.db.models import Dataset, Shard
14+
from lavender_data.server.db import db_manual_session
15+
from lavender_data.server.db.models import (
16+
Dataset,
17+
Shard,
18+
Shardset,
19+
DatasetPublic,
20+
ShardPublic,
21+
ShardsetPublic,
22+
DatasetColumnPublic,
23+
)
1424
from lavender_data.server.cache import CacheClient, get_cache
1525
from lavender_data.server.reader import (
1626
get_reader_instance,
@@ -30,8 +40,17 @@
3040
torch = None
3141

3242

43+
class _Shardset(ShardsetPublic):
44+
shards: list[ShardPublic]
45+
columns: list[DatasetColumnPublic]
46+
47+
48+
class _Dataset(DatasetPublic):
49+
shardsets: list[_Shardset]
50+
51+
3352
def _read_dataset(
34-
dataset: Dataset,
53+
dataset: _Dataset,
3554
index: int,
3655
reader: ReaderInstance,
3756
cache: CacheClient,
@@ -165,14 +184,36 @@ def preview_dataset(
165184
offset: int,
166185
limit: int,
167186
) -> list[dict[str, Any]]:
168-
session = next(get_session())
169187
cache = next(get_cache())
170188
reader = get_reader_instance()
171189

172-
try:
173-
dataset = session.get_one(Dataset, dataset_id)
174-
except NoResultFound:
175-
raise ValueError(f"Dataset {dataset_id} not found")
190+
cached_dataset = cache.hget(f"preview:{dataset_id}", "dataset")
191+
if cached_dataset is None:
192+
with db_manual_session() as session:
193+
dataset = session.exec(
194+
select(Dataset)
195+
.where(Dataset.id == dataset_id)
196+
.options(
197+
selectinload(Dataset.shardsets).options(
198+
selectinload(Shardset.columns),
199+
selectinload(Shardset.shards),
200+
)
201+
)
202+
).one()
203+
204+
if dataset is None:
205+
raise ValueError(f"Dataset {dataset_id} not found")
206+
207+
dataset = _Dataset.model_validate(dataset)
208+
cache.hset(f"preview:{dataset_id}", "dataset", dataset.model_dump_json())
209+
for shardset in dataset.shardsets:
210+
cache.hset(
211+
f"preview:{dataset_id}",
212+
f"dataset.shardsets.{shardset.id}",
213+
shardset.model_dump_json(),
214+
)
215+
else:
216+
dataset = _Dataset.model_validate_json(cached_dataset)
176217

177218
samples = []
178219
for index in range(offset, offset + limit):

0 commit comments

Comments
 (0)