|
2 | 2 |
|
3 | 3 | import asyncio |
4 | 4 | import concurrent.futures |
| 5 | +import contextvars |
5 | 6 | import functools |
6 | 7 | import json |
7 | 8 | import logging |
@@ -102,6 +103,50 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: |
102 | 103 | # Create router |
103 | 104 | kb_router = APIRouter(prefix="/api/kb", tags=["kb"]) |
104 | 105 |
|
# Shared executor for ingestion tasks to prevent global thread pool exhaustion.
# Capped at 10 workers; worker threads carry the "ingest_worker_" name prefix.
_ingest_executor = concurrent.futures.ThreadPoolExecutor(
    max_workers=10, thread_name_prefix="ingest_worker_"
)


def shutdown_ingest_executor(timeout: float = 30.0) -> None:
    """Shut down the shared ingestion executor, waiting at most *timeout* seconds.

    Intended to be called during application shutdown. The executor is first
    asked to drain (finish queued and running tasks) on a helper thread so the
    caller can bound how long it waits. If the drain does not finish in time,
    or fails, tasks that have not started yet are cancelled.

    Args:
        timeout: Maximum number of seconds to wait for in-flight ingestion
            tasks to finish. Defaults to 30 seconds.

    Returns:
        None. Failures are logged rather than raised so that application
        shutdown is never aborted by this function.
    """
    # Function-scope import keeps this block independent of the module's
    # top-level import list.
    import threading

    logger.info("Shutting down ingestion executor...")

    # Set only when the drain finished successfully, so a failed drain is not
    # mistaken for either a clean shutdown or a timeout.
    drained = threading.Event()

    def _drain() -> None:
        # Blocks until every queued and running ingestion task has finished.
        try:
            _ingest_executor.shutdown(wait=True)
        except Exception as exc:
            logger.exception("Error during executor shutdown: %s", exc)
        else:
            drained.set()

    # Daemon thread: a drain that never finishes must not keep the process alive.
    waiter = threading.Thread(
        target=_drain, name="ingest_executor_shutdown", daemon=True
    )
    try:
        waiter.start()
        # join() returns as soon as the drain ends (success or failure) or the
        # timeout elapses, whichever comes first.
        waiter.join(timeout)
    except Exception as exc:
        # e.g. RuntimeError when the interpreter cannot start a new thread.
        logger.exception("Error during executor shutdown: %s", exc)

    if drained.is_set():
        logger.info("Ingestion executor shutdown complete")
        return

    if waiter.is_alive():
        logger.warning(
            "Executor shutdown timed out after %gs; cancelling queued tasks. "
            "Tasks already running cannot be interrupted and may be incomplete.",
            timeout,
        )

    # shutdown(wait=False) alone would leave queued tasks to run later;
    # cancel_futures=True (Python 3.9+) drops everything not yet started.
    # Tasks that are already executing are not interrupted by this call.
    try:
        _ingest_executor.shutdown(wait=False, cancel_futures=True)
    except Exception as exc:
        # Best effort only: report the failure instead of silently ignoring it.
        logger.warning("Forced executor shutdown failed: %s", exc)
| 149 | + |
105 | 150 |
|
106 | 151 | class CloudFile(BaseModel): |
107 | 152 | provider: str |
@@ -427,19 +472,18 @@ async def ingest( |
427 | 472 |
|
428 | 473 | progress_manager = get_progress_manager() |
429 | 474 |
|
430 | | - def _run_ingestion() -> IngestionResult: |
431 | | - return run_document_ingestion( |
| 475 | + result: IngestionResult = await asyncio.get_running_loop().run_in_executor( |
| 476 | + _ingest_executor, |
| 477 | + lambda: contextvars.copy_context().run( |
| 478 | + run_document_ingestion, |
432 | 479 | collection=collection, |
433 | 480 | source_path=str(file_path), |
434 | 481 | ingestion_config=config, |
435 | 482 | progress_manager=progress_manager, |
436 | 483 | user_id=int(_user.id), |
437 | 484 | is_admin=bool(_user.is_admin), |
438 | | - ) |
439 | | - |
440 | | - with concurrent.futures.ThreadPoolExecutor() as executor: |
441 | | - future = executor.submit(_run_ingestion) |
442 | | - result: IngestionResult = future.result() |
| 485 | + ), |
| 486 | + ) |
443 | 487 |
|
444 | 488 | if result.status == "error": |
445 | 489 | return JSONResponse(status_code=500, content=result.model_dump()) |
@@ -545,14 +589,17 @@ def _download_file() -> None: |
545 | 589 |
|
546 | 590 | # Run ingestion (blocking) |
547 | 591 | try: |
548 | | - result = await asyncio.to_thread( |
549 | | - run_document_ingestion, |
550 | | - collection=request.collection, |
551 | | - source_path=str(file_path), |
552 | | - ingestion_config=config, |
553 | | - progress_manager=progress_manager, |
554 | | - user_id=int(_user.id), |
555 | | - is_admin=bool(_user.is_admin), |
| 592 | + result = await asyncio.get_running_loop().run_in_executor( |
| 593 | + _ingest_executor, |
| 594 | + lambda: contextvars.copy_context().run( |
| 595 | + run_document_ingestion, |
| 596 | + collection=request.collection, |
| 597 | + source_path=str(file_path), |
| 598 | + ingestion_config=config, |
| 599 | + progress_manager=progress_manager, |
| 600 | + user_id=int(_user.id), |
| 601 | + is_admin=bool(_user.is_admin), |
| 602 | + ), |
556 | 603 | ) |
557 | 604 | return result |
558 | 605 | except Exception as e: |
@@ -864,15 +911,16 @@ async def ingest_web( |
864 | 911 | max_retries: Maximum retry attempts |
865 | 912 | retry_delay: Delay between retries |
866 | 913 | """ |
| 914 | + # SECURITY: Validate collection name at API boundary |
867 | 915 | try: |
868 | | - try: |
869 | | - safe_collection = sanitize_path_component(collection, "collection") |
870 | | - except ValueError as e: |
871 | | - logger.warning("Invalid collection name rejected: %s - %s", collection, e) |
872 | | - raise HTTPException( |
873 | | - status_code=422, detail=f"Invalid collection name: {str(e)}" |
874 | | - ) from e |
| 916 | + safe_collection = sanitize_path_component(collection, "collection") |
| 917 | + except ValueError as e: |
| 918 | + logger.warning("Invalid collection name rejected: %s - %s", collection, e) |
| 919 | + raise HTTPException( |
| 920 | + status_code=422, detail=f"Invalid collection name: {str(e)}" |
| 921 | + ) from e |
875 | 922 |
|
| 923 | + try: |
876 | 924 | url_patterns_list = ( |
877 | 925 | [p.strip() for p in url_patterns.split(",")] if url_patterns else None |
878 | 926 | ) |
@@ -958,16 +1006,18 @@ async def ingest_web( |
958 | 1006 | ), |
959 | 1007 | ) |
960 | 1008 |
|
961 | | - result = await asyncio.get_event_loop().run_in_executor( |
962 | | - None, |
963 | | - lambda: asyncio.run( |
| 1009 | + # Run web ingestion |
| 1010 | + result = await asyncio.get_running_loop().run_in_executor( |
| 1011 | + _ingest_executor, |
| 1012 | + lambda: contextvars.copy_context().run( |
| 1013 | + asyncio.run, |
964 | 1014 | run_web_ingestion( |
965 | 1015 | collection=safe_collection, |
966 | 1016 | crawl_config=crawl_config, |
967 | 1017 | ingestion_config=ingestion_config, |
968 | 1018 | user_id=int(_user.id), |
969 | 1019 | is_admin=bool(_user.is_admin), |
970 | | - ) |
| 1020 | + ), |
971 | 1021 | ), |
972 | 1022 | ) |
973 | 1023 |
|
@@ -1321,6 +1371,10 @@ async def delete_document_api( |
1321 | 1371 | This endpoint uses filename lookup which may have a race condition if |
1322 | 1372 | the same filename is uploaded multiple times concurrently. For production |
1323 | 1373 | use, consider using doc_id directly or adding a filename index column. |
| 1374 | +
|
| 1375 | + Physical file cleanup is handled at the collection level (when the entire |
| 1376 | + collection is deleted). This endpoint only removes the document from |
| 1377 | + LanceDB vector storage. |
1324 | 1378 | """ |
1325 | 1379 | # NOTE: Exceptions are normalized by @handle_kb_exceptions for consistent API responses. |
1326 | 1380 | from ...core.tools.core.RAG_tools.LanceDB.schema_manager import ( |
@@ -1642,33 +1696,66 @@ async def rename_collection_api( |
1642 | 1696 | physical_rename_error = f"Path resolution error: {str(e)}" |
1643 | 1697 |
|
1644 | 1698 | # Step 2: Update collection name in all tables |
| 1699 | + # NOTE: LanceDB does not support ACID transactions across multiple tables. |
| 1700 | + # We use a best-effort approach: validate all tables are accessible first, |
| 1701 | + # then perform updates. If any update fails, we record the failure but cannot |
| 1702 | + # roll back previously completed updates. |
1645 | 1703 | table_names = _list_table_names(conn, warnings) |
1646 | 1704 |
|
| 1705 | + # Pre-flight validation: ensure all target tables are accessible |
| 1706 | + tables_to_update = [] |
1647 | 1707 | for table_name in ["documents", "parses", "chunks"]: |
1648 | 1708 | if table_name in table_names: |
1649 | | - try: |
1650 | | - table = conn.open_table(table_name) |
1651 | | - table.update( |
1652 | | - f"collection = '{escape_lancedb_string(collection_name)}'", |
1653 | | - {"collection": new_name}, |
1654 | | - ) |
1655 | | - except Exception as e: |
1656 | | - logger.warning("Failed to update '%s': %s", table_name, e) |
1657 | | - warnings.append(f"Failed to update '{table_name}': {e}") |
| 1709 | + tables_to_update.append(table_name) |
1658 | 1710 |
|
1659 | 1711 | for table_name in table_names: |
1660 | | - if not table_name.startswith("embeddings_"): |
1661 | | - continue |
| 1712 | + if table_name.startswith("embeddings_"): |
| 1713 | + tables_to_update.append(table_name) |
| 1714 | + |
| 1715 | + # Validate all tables can be opened |
| 1716 | + validation_errors = [] |
| 1717 | + for table_name in tables_to_update: |
| 1718 | + try: |
| 1719 | + conn.open_table(table_name) |
| 1720 | + except Exception as e: |
| 1721 | + validation_errors.append(f"Cannot access table '{table_name}': {e}") |
| 1722 | + |
| 1723 | + if validation_errors: |
| 1724 | + # Fail fast if any table is inaccessible |
| 1725 | + error_msg = f"Cannot rename collection: {', '.join(validation_errors)}" |
| 1726 | + logger.error("Rename validation failed: %s", error_msg) |
| 1727 | + raise HTTPException( |
| 1728 | + status_code=500, |
| 1729 | + detail=error_msg, |
| 1730 | + ) |
| 1731 | + |
| 1732 | + # Perform all updates (best-effort, no rollback) |
| 1733 | + update_success = {} |
| 1734 | + update_failed = {} |
| 1735 | + |
| 1736 | + for table_name in tables_to_update: |
1662 | 1737 | try: |
1663 | 1738 | table = conn.open_table(table_name) |
1664 | 1739 | table.update( |
1665 | | - f"collection = '{escape_lancedb_string(collection_name)}'", |
| 1740 | + f"collection = '{escape_lancedb_string(collection_name)}' AND user_id = {int(_user.id)}", |
1666 | 1741 | {"collection": new_name}, |
1667 | 1742 | ) |
| 1743 | + update_success[table_name] = True |
1668 | 1744 | except Exception as e: |
1669 | | - logger.warning("Failed to update embeddings table '%s': %s", table_name, e) |
| 1745 | + update_failed[table_name] = str(e) |
| 1746 | + logger.error("Failed to update table '%s': %s", table_name, e) |
1670 | 1747 | warnings.append(f"Failed to update '{table_name}': {e}") |
1671 | 1748 |
|
| 1749 | + # Log summary of update results |
| 1750 | + if update_failed: |
| 1751 | + logger.warning( |
| 1752 | + "Collection rename update summary: %d succeeded, %d failed. Success: %s, Failed: %s", |
| 1753 | + len(update_success), |
| 1754 | + len(update_failed), |
| 1755 | + list(update_success.keys()), |
| 1756 | + list(update_failed.keys()), |
| 1757 | + ) |
| 1758 | + |
1672 | 1759 | # Migrate ingestion status from old collection name to new |
1673 | 1760 | try: |
1674 | 1761 | status_entries = load_ingestion_status(collection=collection_name) |
@@ -1734,11 +1821,13 @@ async def rename_collection_api( |
1734 | 1821 | "status": final_status, |
1735 | 1822 | "message": final_message, |
1736 | 1823 | "warnings": warnings, |
| 1824 | + "physical_move": physical_rename_status, |
1737 | 1825 | } |
1738 | 1826 |
|
1739 | 1827 | return { |
1740 | 1828 | "status": "success", |
1741 | 1829 | "message": f"Collection renamed from '{collection_name}' to '{new_name}'", |
| 1830 | + "physical_move": physical_rename_status, |
1742 | 1831 | } |
1743 | 1832 |
|
1744 | 1833 |
|
|
0 commit comments