Skip to content

Commit 96dd12c

Browse files
committed
[owl] Add model replacement route (#922)
Added an admin model management route to replace the model used in a GenTable. Replacing an embedding model is not currently supported. Hardened Redis async client creation with a loop-aware design.
1 parent 0b2e755 commit 96dd12c

12 files changed

Lines changed: 1464 additions & 22 deletions

File tree

clients/python/src/jamaibase/client.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
EmbeddingResponse,
4343
FileUploadResponse,
4444
GenConfigUpdateRequest,
45+
GenTableModelReplaceProgressKeys,
46+
GenTableModelReplaceRequest,
4547
GetURLRequest,
4648
GetURLResponse,
4749
KnowledgeTableSchemaCreate,
@@ -1229,6 +1231,37 @@ async def delete_deployment(self, deployment_id: str, **kwargs) -> OkResponse:
12291231
**kwargs,
12301232
)
12311233

1234+
async def replace_model_ids(
    self,
    request: GenTableModelReplaceRequest,
    **kwargs,
) -> OkResponse:
    """
    Submit a request to replace model IDs in GenTable generation configs.

    Issues a POST to `/v2/models/replace` with the request as the body.

    Args:
        request (GenTableModelReplaceRequest): The model replacement request.
        **kwargs: Extra keyword arguments forwarded unchanged to the POST helper.

    Returns:
        response (OkResponse): Response containing the progress key.
    """
    endpoint = "/v2/models/replace"
    response = await self._post(
        endpoint,
        body=request,
        response_model=OkResponse,
        **kwargs,
    )
    return response
1254+
1255+
async def list_model_replace_progress_keys(
    self,
    **kwargs,
) -> GenTableModelReplaceProgressKeys:
    """
    Fetch the GenTable model replacement progress keys.

    Issues a GET to `/v2/models/replace/progress_keys`.

    Args:
        **kwargs: Extra keyword arguments forwarded unchanged to the GET helper.

    Returns:
        response (GenTableModelReplaceProgressKeys): The parsed server response.
    """
    endpoint = "/v2/models/replace/progress_keys"
    progress_keys = await self._get(
        endpoint,
        response_model=GenTableModelReplaceProgressKeys,
        **kwargs,
    )
    return progress_keys
1264+
12321265

12331266
class _OrganizationsAsync(_ClientAsync):
12341267
"""Organization methods."""
@@ -4731,6 +4764,19 @@ def update_deployment(
47314764
def delete_deployment(self, deployment_id: str, **kwargs) -> OkResponse:
47324765
return LOOP.run(super().delete_deployment(deployment_id, **kwargs))
47334766

4767+
def replace_model_ids(
    self,
    request: GenTableModelReplaceRequest,
    **kwargs,
) -> OkResponse:
    """
    Blocking counterpart of the async `replace_model_ids`.

    Builds the parent class's coroutine and runs it to completion via `LOOP.run`.

    Args:
        request (GenTableModelReplaceRequest): The model replacement request.
        **kwargs: Extra keyword arguments forwarded to the async implementation.

    Returns:
        response (OkResponse): Whatever the async implementation returns.
    """
    pending = super().replace_model_ids(request, **kwargs)
    return LOOP.run(pending)
4773+
4774+
def list_model_replace_progress_keys(
    self,
    **kwargs,
) -> GenTableModelReplaceProgressKeys:
    """
    Blocking counterpart of the async `list_model_replace_progress_keys`.

    Builds the parent class's coroutine and runs it to completion via `LOOP.run`.

    Args:
        **kwargs: Extra keyword arguments forwarded to the async implementation.

    Returns:
        response (GenTableModelReplaceProgressKeys): Whatever the async implementation returns.
    """
    pending = super().list_model_replace_progress_keys(**kwargs)
    return LOOP.run(pending)
4779+
47344780

47354781
class _Organizations(_OrganizationsAsync):
47364782
"""Organization methods."""

clients/python/src/jamaibase/types/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,8 @@
164164
DiscriminatedGenConfig,
165165
EmbedGenConfig,
166166
GenConfigUpdateRequest,
167+
GenTableModelReplaceProgressKeys,
168+
GenTableModelReplaceRequest,
167169
ImageGenConfig,
168170
KnowledgeTableSchemaCreate,
169171
LLMGenConfig,

clients/python/src/jamaibase/types/gen_table.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,31 @@ def check_column_map(self) -> Self:
392392
return self
393393

394394

395+
class GenTableModelReplaceRequest(BaseModel):
    # Request payload describing which model IDs to swap in GenTable generation configs.
    # Documented with comments rather than a class docstring so the generated schema
    # description is left untouched.

    # At least one entry is required (min_length=1).
    mapping: dict[str, str] = Field(
        description="Mapping of old model IDs to replacement model IDs.",
        min_length=1,
    )
    organization_ids: list[str] | None = Field(
        default=None,
        description="Optional organization IDs to scan. If omitted, all organizations are scanned.",
    )

    @model_validator(mode="after")
    def check_mapping(self) -> Self:
        """Reject a mapping that contains an entry whose old and new model IDs are equal."""
        # Only the first offending entry (in mapping order) is reported.
        identity_entry = next(
            (
                source_id
                for source_id, target_id in self.mapping.items()
                if source_id == target_id
            ),
            None,
        )
        if identity_entry is not None:
            raise ValueError(f'Model replacement maps "{identity_entry}" to itself.')
        return self
411+
412+
413+
class GenTableModelReplaceProgressKeys(BaseModel):
    # Payload carrying a list of GenTable model replacement progress keys.
    # Documented with comments rather than a class docstring so the generated schema
    # description is left untouched.

    # Defaults to a fresh empty list when no keys are supplied.
    items: list[str] = Field(
        description="Recent GenTable model replacement progress keys, newest first.",
        default_factory=list,
    )
418+
419+
395420
class ColumnRenameRequest(BaseModel):
396421
table_id: str = Field(
397422
description="Table name or ID.",

docker/override.dev.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ services:
5656
- ${PWD}/docker_data/owl/db:/usr/src/app/db
5757
- ${PWD}/services/api/src:/usr/src/app/api/src
5858

59+
starling:
60+
volumes:
61+
- ${PWD}/services/api/src:/usr/src/app/api/src
62+
5963
kopi:
6064
ports:
6165
- 5569:3000

services/api/src/owl/docparse.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -245,9 +245,7 @@ async def load_document(
245245
md = await CACHE.get(cache_key)
246246
if md is not None:
247247
# Extend cache TTL
248-
await CACHE._redis_async.expire(
249-
cache_key, ENV_CONFIG.document_loader_cache_ttl_sec
250-
)
248+
await CACHE.expire(cache_key, ENV_CONFIG.document_loader_cache_ttl_sec)
251249
logger.debug(f'File "{file_name}" loaded from cache (cache key="{cache_key}").')
252250
return md
253251
try:
@@ -361,7 +359,7 @@ async def load_document_chunks(
361359
chunk_json_str = await CACHE.get(cache_key)
362360
if chunk_json_str is not None:
363361
# Extend cache TTL
364-
await CACHE._redis_async.expire(cache_key, cache_ttl)
362+
await CACHE.expire(cache_key, cache_ttl)
365363
logger.info(
366364
f'File chunks "{file_name}" loaded from cache (cache key="{cache_key}").'
367365
)

services/api/src/owl/routers/models.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,25 @@
55
from sqlmodel import func, select
66

77
from owl.db import AsyncSession, yield_async_session
8-
from owl.db.models import Deployment, ModelConfig
8+
from owl.db.models import Deployment, ModelConfig, Organization
9+
from owl.tasks.gen_table import replace_gen_table_models
910
from owl.types import (
1011
CloudProvider,
1112
DeploymentCreate,
1213
DeploymentRead,
1314
DeploymentUpdate,
15+
GenTableModelReplaceProgressKeys,
16+
GenTableModelReplaceRequest,
1417
ListQuery,
1518
ModelConfigCreate,
1619
ModelConfigRead,
1720
ModelConfigUpdate,
21+
ModelType,
1822
OkResponse,
1923
Page,
2024
UserAuth,
2125
)
26+
from owl.utils import uuid7_str
2227
from owl.utils.auth import auth_user_service_key, has_permissions
2328
from owl.utils.dates import now
2429
from owl.utils.exceptions import (
@@ -27,6 +32,14 @@
2732
ResourceNotFoundError,
2833
handle_exception,
2934
)
35+
from owl.utils.gen_table_model_replace import (
36+
acquire_model_replace_lock,
37+
delete_model_replace_progress,
38+
get_model_replace_lock,
39+
initialize_model_replace_progress,
40+
list_recent_model_replace_progress_keys,
41+
release_model_replace_lock,
42+
)
3043

3144
router = APIRouter()
3245

@@ -316,3 +329,76 @@ async def delete_deployment(
316329
await session.delete(deployment)
317330
await session.commit()
318331
return OkResponse()
332+
333+
334+
@router.post(
    "/v2/models/replace",
    summary="Replace model IDs in GenTable generation configs.",
    description="Permissions: `system.MEMBER`.",
)
@handle_exception
async def replace_gen_table_model_ids(
    request: Request,
    user: Annotated[UserAuth, Depends(auth_user_service_key)],
    session: Annotated[AsyncSession, Depends(yield_async_session)],
    body: GenTableModelReplaceRequest,
) -> OkResponse:
    """
    Validate a GenTable model replacement request and enqueue it as a background task.

    The referenced organizations and model configs are looked up first, then the
    model replace lock is acquired, the progress record is initialised and the
    `replace_gen_table_models` task is enqueued.

    Args:
        request (Request): Incoming request; its `state.id` is used for logging.
        user (UserAuth): Authenticated caller; must hold `system.MEMBER`.
        session (AsyncSession): Database session used for the lookups.
        body (GenTableModelReplaceRequest): Model ID mapping and optional organization filter.

    Returns:
        response (OkResponse): Response built with the progress key of the enqueued task.

    Raises:
        BadInputError: If any model ID in the mapping resolves to an embedding model.
        ResourceExistsError: If the model replace lock could not be acquired.
    """
    logger.info(f"{request.state.id} - Replacing GenTable model IDs: {body}")
    has_permissions(user, ["system.MEMBER"])
    # Look up every referenced organization up front.
    # NOTE(review): presumably `Organization.get` raises when the ID is unknown - confirm.
    if body.organization_ids is not None:
        for organization_id in body.organization_ids:
            await Organization.get(session, organization_id, name="Organization")
    # Both the old IDs (keys) and the replacement IDs (values) are looked up.
    model_configs = {
        model_id: await ModelConfig.get(session, model_id, name="Model config")
        for model_id in sorted(set(body.mapping) | set(body.mapping.values()))
    }
    embedding_model_ids = [
        model_id
        for model_id, model_config in model_configs.items()
        if model_config is not None and model_config.type == ModelType.EMBED
    ]
    if embedding_model_ids:
        raise BadInputError(f"Replacing embedding model is not supported: {embedding_model_ids}.")
    progress_key = f"gen_table_model_replace:{uuid7_str()}"
    if not await acquire_model_replace_lock(progress_key):
        # Report the progress key stored in the lock, when one is available.
        active_progress_key = await get_model_replace_lock()
        raise ResourceExistsError(
            (
                "Another GenTable model replace task is already running"
                + (f': progress_key="{active_progress_key}".' if active_progress_key else ".")
            )
        )
    try:
        await initialize_model_replace_progress(
            progress_key=progress_key,
            mapping=body.mapping,
            organization_ids=body.organization_ids,
            requested_by=user.id,
        )
        replace_gen_table_models.delay(
            mapping=body.mapping,
            organization_ids=body.organization_ids,
            progress_key=progress_key,
            requested_by=user.id,
        )
    except Exception:
        # Run each clean-up step independently: a failure to release the lock must not
        # skip deleting the progress record, and a clean-up failure must not replace
        # the original error that is re-raised below.
        for cleanup in (release_model_replace_lock, delete_model_replace_progress):
            try:
                await cleanup(progress_key)
            except Exception:
                logger.exception(
                    "Failed to clean up GenTable model replace task "
                    f'progress_key="{progress_key}".'
                )
        raise
    logger.bind(user_id=user.id).info(
        f'Enqueued GenTable model replace task progress_key="{progress_key}".'
    )
    return OkResponse(progress_key=progress_key)
392+
393+
394+
@router.get(
    "/v2/models/replace/progress_keys",
    summary="List recent GenTable model replacement progress keys.",
    description="Permissions: `system.MEMBER`.",
)
@handle_exception
async def list_gen_table_model_replace_progress_keys(
    user: Annotated[UserAuth, Depends(auth_user_service_key)],
) -> GenTableModelReplaceProgressKeys:
    """
    Return the recent GenTable model replacement progress keys.

    Args:
        user (UserAuth): Authenticated caller; must hold `system.MEMBER`.

    Returns:
        response (GenTableModelReplaceProgressKeys): Keys wrapped in the response model.
    """
    has_permissions(user, ["system.MEMBER"])
    recent_keys = await list_recent_model_replace_progress_keys()
    return GenTableModelReplaceProgressKeys(items=recent_keys)

services/api/src/owl/tasks/gen_table.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from owl.db.gen_table import ActionTable, ChatTable, KnowledgeTable
88
from owl.types import TableType
99
from owl.utils.exceptions import JamaiException, ResourceExistsError
10+
from owl.utils.gen_table_model_replace import GenTableModelReplacer, release_model_replace_lock
1011
from owl.utils.io import open_uri_async
1112

1213
TABLE_CLS: dict[TableType, ActionTable | KnowledgeTable | ChatTable] = {
@@ -60,3 +61,29 @@ async def _task():
6061
table = asyncio.get_event_loop().run_until_complete(_task())
6162
logger.info("Generative Table import task completed.")
6263
return table.v1_meta_response.model_dump_json()
64+
65+
66+
@celery_app.task
def replace_gen_table_models(
    mapping: dict[str, str],
    *,
    organization_ids: list[str] | None = None,
    progress_key: str = "",
    requested_by: str = "",
) -> dict[str, bool | int]:
    """
    Celery task that runs a `GenTableModelReplacer` and then releases the replace lock.

    Args:
        mapping (dict[str, str]): Passed to the replacer as `mapping`.
        organization_ids (list[str] | None): Passed to the replacer as `organization_ids`.
        progress_key (str): Passed to the replacer and used to release the lock.
        requested_by (str): Passed to the replacer as `requested_by`.

    Returns:
        result (dict[str, bool | int]): The replacer's result converted with `to_dict()`.
    """

    async def _replace_and_unlock():
        # The replacer is built inside the `try` so the lock is released
        # even when construction itself fails.
        try:
            return await GenTableModelReplacer(
                mapping=mapping,
                organization_ids=organization_ids,
                progress_key=progress_key,
                requested_by=requested_by,
            ).run()
        finally:
            await release_model_replace_lock(progress_key)

    logger.info("GenTable model replace task started.")
    loop = asyncio.get_event_loop()
    outcome = loop.run_until_complete(_replace_and_unlock())
    logger.info("GenTable model replace task completed.")
    return outcome.to_dict()

services/api/src/owl/types/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@
7575
FunctionCall,
7676
FunctionParameters,
7777
GenConfigUpdateRequest,
78+
GenTableModelReplaceProgressKeys,
79+
GenTableModelReplaceRequest,
7880
GetURLRequest,
7981
GetURLResponse,
8082
Host,

0 commit comments

Comments
 (0)