-
Notifications
You must be signed in to change notification settings - Fork 322
Expand file tree
/
Copy pathdb_models.py
More file actions
155 lines (112 loc) · 5.46 KB
/
db_models.py
File metadata and controls
155 lines (112 loc) · 5.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""Database models for the Tinker API."""
from datetime import datetime, timezone
from enum import Enum
from sqlalchemy import DateTime, event
from sqlalchemy.engine import url as sqlalchemy_url
from sqlmodel import JSON, Field, SQLModel
from skyrl.tinker import types
def enable_sqlite_wal(engine) -> None:
    """Enable WAL mode and busy timeout for SQLite engines.

    WAL mode allows concurrent readers with a single writer.
    Busy timeout makes SQLite retry internally instead of immediately
    raising 'database is locked'.

    No-op for non-SQLite engines.

    Args:
        engine: A SQLAlchemy ``Engine`` (sync or async-wrapped); only its
            dialect name and connect event are used.
    """
    if engine.dialect.name != "sqlite":
        return

    @event.listens_for(engine, "connect")
    def _set_sqlite_pragma(dbapi_connection, connection_record):
        # Runs once per new DBAPI connection in the pool.
        cursor = dbapi_connection.cursor()
        try:
            cursor.execute("PRAGMA journal_mode=WAL")
            cursor.execute("PRAGMA busy_timeout=30000")
        finally:
            # Fix: close the cursor even if a PRAGMA fails, so a broken
            # connection does not leak an open cursor.
            cursor.close()
def get_async_database_url(db_url: str) -> str:
    """Rewrite a database URL so it uses an async SQLAlchemy driver.

    Args:
        db_url: Database URL to convert (e.g. ``sqlite:///app.db``).

    Returns:
        Async database URL string for SQLAlchemy.

    Raises:
        ValueError: If the database scheme is not supported and no explicit
            driver is already present in the URL.
    """
    parsed = sqlalchemy_url.make_url(db_url)
    backend_name = parsed.get_backend_name()

    if backend_name == "sqlite":
        rewritten = parsed.set(drivername="sqlite+aiosqlite")
    elif backend_name == "postgresql":
        rewritten = parsed.set(drivername="postgresql+asyncpg")
    elif "+" in parsed.drivername:
        # Already has an async driver specified, keep it
        rewritten = parsed
    else:
        raise ValueError(f"Unsupported database scheme: {backend_name}")

    return rewritten.render_as_string(hide_password=False)
class RequestStatus(str, Enum):
    """Lifecycle state of an asynchronous request.

    Inherits from ``str`` so members compare equal to their raw values
    and serialize transparently into string database columns.
    """

    PENDING = "pending"      # queued, not yet finished
    COMPLETED = "completed"  # finished successfully
    FAILED = "failed"        # finished with an error
class CheckpointStatus(str, Enum):
    """Lifecycle state of a checkpoint.

    Inherits from ``str`` so members compare equal to their raw values
    and serialize transparently into string database columns.
    """

    PENDING = "pending"      # checkpoint write in progress
    COMPLETED = "completed"  # checkpoint written successfully
    FAILED = "failed"        # checkpoint write failed
# SQLModel table definitions
class ModelDB(SQLModel, table=True):
    """Row for a model created within a session."""

    __tablename__ = "models"

    model_id: str = Field(primary_key=True)
    # Name of the base model this model was derived from.
    base_model: str
    # LoRA configuration stored as an opaque JSON blob; schema is set by the caller.
    lora_config: dict[str, object] = Field(sa_type=JSON)
    # Free-form status string, indexed for filtering. NOTE(review): unlike
    # FutureDB.status this is not an Enum-typed column -- confirm allowed values.
    status: str = Field(index=True)
    # The request (FutureDB row) that created this model. NOTE(review): not
    # declared as a foreign key to futures.request_id -- confirm intentional.
    request_id: int
    # Owning session; indexed to list a session's models quickly.
    session_id: str = Field(foreign_key="sessions.session_id", index=True)
    # Timezone-aware creation timestamp (UTC).
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), sa_type=DateTime(timezone=True))
class FutureDB(SQLModel, table=True):
    """Row tracking an asynchronous request ("future") and its result."""

    __tablename__ = "futures"

    # Autoincrementing primary key; doubles as the future's handle.
    request_id: int | None = Field(default=None, primary_key=True, sa_column_kwargs={"autoincrement": True})
    request_type: types.RequestType
    # Model the request targets, if any; indexed for per-model queries.
    model_id: str | None = Field(default=None, index=True)
    request_data: dict = Field(sa_type=JSON)  # this is of type types.{request_type}Input
    result_data: dict | None = Field(default=None, sa_type=JSON)  # this is of type types.{request_type}Output
    # Indexed so pollers can cheaply scan for PENDING rows.
    status: RequestStatus = Field(default=RequestStatus.PENDING, index=True)
    # Timezone-aware timestamps (UTC); completed_at stays None until resolved.
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), sa_type=DateTime(timezone=True))
    completed_at: datetime | None = Field(default=None, sa_type=DateTime(timezone=True))
class CheckpointDB(SQLModel, table=True):
    """Row for a model checkpoint.

    Composite primary key (model_id, checkpoint_id, checkpoint_type): the
    same checkpoint id may exist once per checkpoint type per model.
    """

    __tablename__ = "checkpoints"

    model_id: str = Field(foreign_key="models.model_id", primary_key=True)
    checkpoint_id: str = Field(primary_key=True)
    checkpoint_type: types.CheckpointType = Field(primary_key=True)
    status: CheckpointStatus
    # Timezone-aware timestamps (UTC); completed_at stays None while pending.
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), sa_type=DateTime(timezone=True))
    completed_at: datetime | None = Field(default=None, sa_type=DateTime(timezone=True))
    # Populated only when status is FAILED (presumably -- TODO confirm against writer).
    error_message: str | None = None
class SessionDB(SQLModel, table=True):
    """Row for a client session, including liveness-tracking fields."""

    __tablename__ = "sessions"

    session_id: str = Field(primary_key=True)
    # Client-supplied tags, stored as a JSON array.
    tags: list[str] = Field(default_factory=list, sa_type=JSON)
    # Arbitrary client-supplied metadata, stored as a JSON object.
    user_metadata: dict = Field(default_factory=dict, sa_type=JSON)
    # Version of the SDK that opened the session.
    sdk_version: str
    # Free-form status string; defaults to "active". NOTE(review): not an
    # Enum-typed column -- confirm the full set of values with the writers.
    status: str = Field(default="active", index=True)
    # Timezone-aware timestamps (UTC). last_heartbeat_at is indexed so stale
    # sessions can be found efficiently; None until the first heartbeat.
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), sa_type=DateTime(timezone=True))
    last_heartbeat_at: datetime | None = Field(default=None, sa_type=DateTime(timezone=True), index=True)
    # Running count of heartbeats received for this session.
    heartbeat_count: int = 0
class SamplingSessionDB(SQLModel, table=True):
    """Row for a sampling session opened under a parent session."""

    __tablename__ = "sampling_sessions"

    sampling_session_id: str = Field(primary_key=True)
    # Parent session; indexed to list a session's sampling sessions.
    session_id: str = Field(foreign_key="sessions.session_id", index=True)
    # Sequence number of this sampling session within its parent session
    # (presumably monotonically assigned by the creator -- TODO confirm).
    sampling_session_seq_id: int
    # Either a base model name or an explicit model path may be set; both
    # are optional.
    base_model: str | None = None
    model_path: str | None = None
    # Timezone-aware creation timestamp (UTC).
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), sa_type=DateTime(timezone=True))
class EngineStateDB(SQLModel, table=True):
    """Engine→API handoff for the inference engine the backend stands up.

    Singleton row (``singleton_id=1``). Written by the backend when a new
    inference client is built (or torn down) and read by the API's
    forwarding client to resolve the vLLM proxy URL.
    """

    __tablename__ = "engine_state"

    # Fixed primary key enforces the single-row invariant.
    singleton_id: int = Field(default=1, primary_key=True)
    # Proxy URL of the engine-managed vLLM. None when no vLLM has been
    # stood up yet (no create_model, FFT path, or last delete tore down).
    inference_proxy_url: str | None = None
    # Timezone-aware timestamp (UTC) of the last write to this row.
    updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), sa_type=DateTime(timezone=True))