Skip to content

Commit 091aa72

Browse files
authored
refactor: refine dynamic model creation and remove PatchSQLModel (#660)
- using `@lru_cache` to avoid creating duplicate dynamic chunk/entity/relationship models. - using `type` function to dynamic create model class, `PatchSQLModel` and custom `SQLModelMetaclass` is no more need
1 parent 1d0007f commit 091aa72

File tree

26 files changed

+251
-985
lines changed

26 files changed

+251
-985
lines changed

backend/app/api/admin_routes/knowledge_base/document/routes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from app.api.admin_routes.knowledge_base.models import ChunkItem
99
from app.api.deps import SessionDep, CurrentSuperuserDep
1010
from app.models import Document
11-
from app.models.chunk import Chunk, KgIndexStatus, get_kb_chunk_model
11+
from app.models.chunk import KgIndexStatus, get_kb_chunk_model
1212
from app.models.document import DocIndexTaskStatus
1313
from app.models.entity import get_kb_entity_model
1414
from app.models.relationship import get_kb_relationship_model
@@ -198,7 +198,7 @@ def rebuild_kb_document_index_by_ids(
198198
build_index_for_document.delay(kb.id, doc.id)
199199

200200
# Retry failed kg index tasks.
201-
chunks: list[Chunk] = kb_chunk_repo.fetch_by_document_ids(db_session, document_ids)
201+
chunks = kb_chunk_repo.fetch_by_document_ids(db_session, document_ids)
202202
reindex_chunk_ids = []
203203
ignore_chunk_ids = []
204204
for chunk in chunks:

backend/app/api/routes/api_key.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class CreateApiKeyResponse(BaseModel):
2121
async def create_api_key(
2222
session: AsyncSessionDep, user: CurrentSuperuserDep, request: CreateApiKeyRequest
2323
) -> CreateApiKeyResponse:
24-
_, raw_api_key = await api_key_manager.create_api_key(
24+
_, raw_api_key = await api_key_manager.acreate_api_key(
2525
session, user, request.description
2626
)
2727
return CreateApiKeyResponse(api_key=raw_api_key)

backend/app/auth/api_keys.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import Optional, Tuple
66

77
from fastapi import Request
8-
from sqlmodel import select
8+
from sqlmodel import Session, select
99
from sqlmodel.ext.asyncio.session import AsyncSession
1010
from fastapi_pagination import Params, Page
1111
from fastapi_pagination.ext.sqlmodel import paginate
@@ -49,7 +49,7 @@ def encrypt_api_key(api_key: str) -> str:
4949

5050

5151
class ApiKeyManager:
52-
async def create_api_key(
52+
async def acreate_api_key(
5353
self, session: AsyncSession, user: User, description: str
5454
) -> Tuple[ApiKey, str]:
5555
api_key = generate_api_key()
@@ -65,6 +65,22 @@ async def create_api_key(
6565
await session.refresh(api_key_obj)
6666
return api_key_obj, api_key
6767

68+
def create_api_key(
69+
self, session: Session, user: User, description: str
70+
) -> Tuple[ApiKey, str]:
71+
api_key = generate_api_key()
72+
hashed_api_key = encrypt_api_key(api_key)
73+
api_key_obj = ApiKey(
74+
hashed_secret=hashed_api_key,
75+
api_key_display=api_key[:7] + "...." + api_key[-3:],
76+
user_id=user.id,
77+
description=description,
78+
)
79+
session.add(api_key_obj)
80+
session.commit()
81+
session.refresh(api_key_obj)
82+
return api_key_obj, api_key
83+
6884
async def get_active_user_by_raw_api_key(
6985
self, session: AsyncSession, api_key: str
7086
) -> Optional[User]:

backend/app/models/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
# flake8: noqa
22
from .entity import (
33
EntityType,
4-
Entity,
54
EntityPublic,
5+
get_kb_entity_model,
66
)
7-
from .relationship import Relationship, RelationshipPublic
7+
from .relationship import RelationshipPublic, get_kb_relationship_model
88
from .feedback import (
99
Feedback,
1010
FeedbackType,
@@ -18,7 +18,7 @@
1818
from .chat import Chat, ChatUpdate, ChatVisibility, ChatFilters, ChatOrigin
1919
from .chat_message import ChatMessage
2020
from .document import Document, DocIndexTaskStatus
21-
from .chunk import Chunk, KgIndexStatus
21+
from .chunk import KgIndexStatus, get_kb_chunk_model
2222
from .auth import User, UserSession
2323
from .api_key import ApiKey, PublicApiKey
2424
from .site_setting import SiteSetting

backend/app/models/chunk.py

Lines changed: 36 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
import enum
2-
from datetime import datetime
3-
from typing import Any, Optional, Type
4-
from uuid import UUID
2+
from functools import lru_cache
53

6-
from sqlalchemy import DateTime, func
4+
from typing import Optional, Type
75
from sqlmodel import (
86
Field,
97
Column,
@@ -15,17 +13,11 @@
1513
from tidb_vector.sqlalchemy import VectorType
1614
from llama_index.core.schema import TextNode
1715

18-
from app.core.config import settings
1916
from app.models.document import Document
2017
from app.models.knowledge_base import KnowledgeBase
18+
from app.models.knowledge_base_scoped.table_naming import get_kb_vector_dims
19+
from app.utils.namespace import format_namespace
2120
from .base import UpdatableBaseModel, UUIDBaseModel
22-
from app.models.knowledge_base_scoped.registry import get_kb_scoped_registry
23-
from .knowledge_base_scoped.table_naming import (
24-
get_kb_chunks_table_name,
25-
get_kb_vector_dims,
26-
)
27-
from app.models.patch.sql_model import SQLModel as PatchSQLModel
28-
from ..utils.uuid6 import uuid7
2921

3022

3123
class KgIndexStatus(str, enum.Enum):
@@ -36,69 +28,26 @@ class KgIndexStatus(str, enum.Enum):
3628
FAILED = "failed"
3729

3830

39-
# Notice: DO NOT forget to modify the definition in `get_kb_chunk_model` to
40-
# keep the table structure on both sides consistent.
41-
class Chunk(UUIDBaseModel, UpdatableBaseModel, table=True):
42-
hash: str = Field(max_length=64)
43-
text: str = Field(sa_column=Column(Text))
44-
meta: dict | list = Field(default={}, sa_column=Column(JSON))
45-
embedding: Any = Field(
46-
sa_column=Column(
47-
VectorType(settings.EMBEDDING_DIMS), comment="hnsw(distance=cosine)"
48-
)
49-
)
50-
document_id: int = Field(foreign_key="documents.id", nullable=True)
51-
document: "Document" = SQLRelationship(
52-
sa_relationship_kwargs={
53-
"lazy": "joined",
54-
"primaryjoin": "Chunk.document_id == Document.id",
55-
},
56-
)
57-
relations: dict | list = Field(default={}, sa_column=Column(JSON))
58-
source_uri: str = Field(max_length=512, nullable=True)
59-
60-
# TODO: Add vector_index_status, vector_index_result column, vector index should be optional in the future.
61-
62-
# TODO: Rename to kg_index_status, kg_index_result column.
63-
index_status: KgIndexStatus = KgIndexStatus.NOT_STARTED
64-
index_result: str = Field(sa_column=Column(Text, nullable=True))
65-
66-
__tablename__ = "chunks"
67-
68-
def to_llama_text_node(self) -> TextNode:
69-
return TextNode(
70-
id_=self.id.hex,
71-
text=self.text,
72-
embedding=list(self.embedding),
73-
metadata=self.meta,
74-
)
75-
76-
7731
def get_kb_chunk_model(kb: KnowledgeBase) -> Type[SQLModel]:
7832
vector_dimension = get_kb_vector_dims(kb)
79-
chunks_table_name = get_kb_chunks_table_name(kb)
80-
ctx = get_kb_scoped_registry(kb)
33+
return get_dynamic_chunk_model(vector_dimension, str(kb.id))
8134

82-
if ctx.chunk_model:
83-
return ctx.chunk_model
8435

85-
class KBChunk(PatchSQLModel, table=True, registry=ctx.registry):
86-
__tablename__ = chunks_table_name
87-
__table_args__ = {"extend_existing": True}
36+
@lru_cache(maxsize=None)
37+
def get_dynamic_chunk_model(
38+
vector_dimension: int,
39+
namespace: Optional[str] = None,
40+
) -> Type[SQLModel]:
41+
namespace = format_namespace(namespace)
42+
chunk_table_name = f"chunks_{namespace}"
43+
chunk_model_name = f"Chunk_{namespace}_{vector_dimension}"
8844

89-
id: UUID = Field(
90-
primary_key=True, index=True, nullable=False, default_factory=uuid7
91-
)
45+
class Chunk(UUIDBaseModel, UpdatableBaseModel):
9246
hash: str = Field(max_length=64)
9347
text: str = Field(sa_column=Column(Text))
94-
meta: dict | list = Field(default={}, sa_column=Column(JSON))
95-
embedding: Any = Field(
96-
sa_column=Column(
97-
VectorType(vector_dimension), comment="hnsw(distance=cosine)"
98-
)
99-
)
48+
meta: dict = Field(default={}, sa_column=Column(JSON))
49+
embedding: list[float] = Field(sa_type=VectorType(vector_dimension))
10050
document_id: int = Field(foreign_key="documents.id", nullable=True)
101-
document: "Document" = SQLRelationship()
10251
relations: dict | list = Field(default={}, sa_column=Column(JSON))
10352
source_uri: str = Field(max_length=512, nullable=True)
10453

@@ -108,16 +57,6 @@ class KBChunk(PatchSQLModel, table=True, registry=ctx.registry):
10857
index_status: KgIndexStatus = KgIndexStatus.NOT_STARTED
10958
index_result: str = Field(sa_column=Column(Text, nullable=True))
11059

111-
created_at: Optional[datetime] = Field(
112-
default=None,
113-
sa_column=Column(DateTime(timezone=True), server_default=func.now()),
114-
)
115-
updated_at: Optional[datetime] = Field(
116-
default=None,
117-
sa_type=DateTime(timezone=True),
118-
sa_column_kwargs={"server_default": func.now(), "onupdate": func.now()},
119-
)
120-
12160
def to_llama_text_node(self) -> TextNode:
12261
return TextNode(
12362
id_=self.id.hex,
@@ -126,5 +65,23 @@ def to_llama_text_node(self) -> TextNode:
12665
metadata=self.meta,
12766
)
12867

129-
ctx.chunk_model = KBChunk
130-
return KBChunk
68+
chunk_model = type(
69+
chunk_model_name,
70+
(Chunk,),
71+
{
72+
"__tablename__": chunk_table_name,
73+
"__table_args__": {"extend_existing": True},
74+
"__annotations__": {
75+
"document": Document,
76+
},
77+
"document": SQLRelationship(
78+
sa_relationship_kwargs={
79+
"lazy": "joined",
80+
"primaryjoin": f"{chunk_model_name}.document_id == Document.id",
81+
},
82+
),
83+
},
84+
table=True,
85+
)
86+
87+
return chunk_model

0 commit comments

Comments
 (0)