|
5 | 5 | from sqlmodel import func, select |
6 | 6 |
|
7 | 7 | from owl.db import AsyncSession, yield_async_session |
8 | | -from owl.db.models import Deployment, ModelConfig |
| 8 | +from owl.db.models import Deployment, ModelConfig, Organization |
| 9 | +from owl.tasks.gen_table import replace_gen_table_models |
9 | 10 | from owl.types import ( |
10 | 11 | CloudProvider, |
11 | 12 | DeploymentCreate, |
12 | 13 | DeploymentRead, |
13 | 14 | DeploymentUpdate, |
| 15 | + GenTableModelReplaceProgressKeys, |
| 16 | + GenTableModelReplaceRequest, |
14 | 17 | ListQuery, |
15 | 18 | ModelConfigCreate, |
16 | 19 | ModelConfigRead, |
17 | 20 | ModelConfigUpdate, |
| 21 | + ModelType, |
18 | 22 | OkResponse, |
19 | 23 | Page, |
20 | 24 | UserAuth, |
21 | 25 | ) |
| 26 | +from owl.utils import uuid7_str |
22 | 27 | from owl.utils.auth import auth_user_service_key, has_permissions |
23 | 28 | from owl.utils.dates import now |
24 | 29 | from owl.utils.exceptions import ( |
|
27 | 32 | ResourceNotFoundError, |
28 | 33 | handle_exception, |
29 | 34 | ) |
| 35 | +from owl.utils.gen_table_model_replace import ( |
| 36 | + acquire_model_replace_lock, |
| 37 | + delete_model_replace_progress, |
| 38 | + get_model_replace_lock, |
| 39 | + initialize_model_replace_progress, |
| 40 | + list_recent_model_replace_progress_keys, |
| 41 | + release_model_replace_lock, |
| 42 | +) |
30 | 43 |
|
31 | 44 | router = APIRouter() |
32 | 45 |
|
@@ -316,3 +329,76 @@ async def delete_deployment( |
316 | 329 | await session.delete(deployment) |
317 | 330 | await session.commit() |
318 | 331 | return OkResponse() |
| 332 | + |
| 333 | + |
| 334 | +@router.post( |
| 335 | + "/v2/models/replace", |
| 336 | + summary="Replace model IDs in GenTable generation configs.", |
| 337 | + description="Permissions: `system.MEMBER`.", |
| 338 | +) |
| 339 | +@handle_exception |
| 340 | +async def replace_gen_table_model_ids( |
| 341 | + request: Request, |
| 342 | + user: Annotated[UserAuth, Depends(auth_user_service_key)], |
| 343 | + session: Annotated[AsyncSession, Depends(yield_async_session)], |
| 344 | + body: GenTableModelReplaceRequest, |
| 345 | +) -> OkResponse: |
| 346 | + logger.info(f"{request.state.id} - Replacing GenTable model IDs: {body}") |
| 347 | + has_permissions(user, ["system.MEMBER"]) |
| 348 | + if body.organization_ids is not None: |
| 349 | + for organization_id in body.organization_ids: |
| 350 | + await Organization.get(session, organization_id, name="Organization") |
| 351 | + model_configs = { |
| 352 | + model_id: await ModelConfig.get(session, model_id, name="Model config") |
| 353 | + for model_id in sorted(set(body.mapping) | set(body.mapping.values())) |
| 354 | + } |
| 355 | + embedding_model_ids = [ |
| 356 | + model_id |
| 357 | + for model_id, model_config in model_configs.items() |
| 358 | + if model_config is not None and model_config.type == ModelType.EMBED |
| 359 | + ] |
| 360 | + if embedding_model_ids: |
| 361 | + raise BadInputError(f"Replacing embedding model is not supported: {embedding_model_ids}.") |
| 362 | + progress_key = f"gen_table_model_replace:{uuid7_str()}" |
| 363 | + if not await acquire_model_replace_lock(progress_key): |
| 364 | + active_progress_key = await get_model_replace_lock() |
| 365 | + raise ResourceExistsError( |
| 366 | + ( |
| 367 | + "Another GenTable model replace task is already running" |
| 368 | + + (f': progress_key="{active_progress_key}".' if active_progress_key else ".") |
| 369 | + ) |
| 370 | + ) |
| 371 | + try: |
| 372 | + await initialize_model_replace_progress( |
| 373 | + progress_key=progress_key, |
| 374 | + mapping=body.mapping, |
| 375 | + organization_ids=body.organization_ids, |
| 376 | + requested_by=user.id, |
| 377 | + ) |
| 378 | + replace_gen_table_models.delay( |
| 379 | + mapping=body.mapping, |
| 380 | + organization_ids=body.organization_ids, |
| 381 | + progress_key=progress_key, |
| 382 | + requested_by=user.id, |
| 383 | + ) |
| 384 | + except Exception: |
| 385 | + await release_model_replace_lock(progress_key) |
| 386 | + await delete_model_replace_progress(progress_key) |
| 387 | + raise |
| 388 | + logger.bind(user_id=user.id).info( |
| 389 | + f'Enqueued GenTable model replace task progress_key="{progress_key}".' |
| 390 | + ) |
| 391 | + return OkResponse(progress_key=progress_key) |
| 392 | + |
| 393 | + |
| 394 | +@router.get( |
| 395 | + "/v2/models/replace/progress_keys", |
| 396 | + summary="List recent GenTable model replacement progress keys.", |
| 397 | + description="Permissions: `system.MEMBER`.", |
| 398 | +) |
| 399 | +@handle_exception |
| 400 | +async def list_gen_table_model_replace_progress_keys( |
| 401 | + user: Annotated[UserAuth, Depends(auth_user_service_key)], |
| 402 | +) -> GenTableModelReplaceProgressKeys: |
| 403 | + has_permissions(user, ["system.MEMBER"]) |
| 404 | + return GenTableModelReplaceProgressKeys(items=await list_recent_model_replace_progress_keys()) |
0 commit comments