-
Notifications
You must be signed in to change notification settings - Fork 175
Expand file tree
/
Copy pathentity.py
More file actions
97 lines (81 loc) · 2.88 KB
/
entity.py
File metadata and controls
97 lines (81 loc) · 2.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import enum
from functools import lru_cache
import sys
from typing import Optional, List, Dict, Type
from sqlmodel import (
SQLModel,
Field,
Column,
JSON,
Text,
)
from pydantic import BaseModel
from tidb_vector.sqlalchemy import VectorType
from sqlalchemy import Index
from app.models.knowledge_base import KnowledgeBase
from app.models.knowledge_base_scoped.table_naming import get_kb_vector_dims
from app.utils.namespace import format_namespace
from app.logger import logger
class EntityType(str, enum.Enum):
original = "original"
synopsis = "synopsis"
def __str__(self):
return self.value
class EntityPublic(BaseModel):
id: int
entity_type: EntityType = Field(default=EntityType.original)
name: str
description: Optional[str] = Field(default=None)
meta: Optional[dict] = Field(default=None)
synopsis_info: Optional[dict] = Field(default=None)
def get_kb_entity_model(kb: KnowledgeBase) -> Type[SQLModel]:
vector_dimension = get_kb_vector_dims(kb)
return get_dynamic_entity_model(vector_dimension, str(kb.id))
@lru_cache(maxsize=sys.maxsize)
def get_dynamic_entity_model(
vector_dimension: int,
namespace: Optional[str] = None,
) -> Type[SQLModel]:
namespace = format_namespace(namespace)
entity_table_name = f"entities_{namespace}"
entity_model_name = f"Entity_{namespace}_{vector_dimension}"
logger.info(
"Dynamic create entity model (dimension: %s, table: %s, model: %s)",
vector_dimension,
entity_table_name,
entity_model_name,
)
class Entity(SQLModel):
id: Optional[int] = Field(default=None, primary_key=True)
name: str = Field(max_length=512)
description: str = Field(sa_column=Column(Text))
meta: dict = Field(default_factory=dict, sa_column=Column(JSON))
entity_type: EntityType = EntityType.original
synopsis_info: List | Dict | None = Field(default=None, sa_column=Column(JSON))
description_vec: list[float] = Field(sa_type=VectorType(vector_dimension))
meta_vec: list[float] = Field(sa_type=VectorType(vector_dimension))
def __hash__(self):
return hash(self.id)
# screenshot method is used to return a dictionary representation of the object
# that can be used for recording or debugging purposes
def screenshot(self):
return self.model_dump(
exclude={
"description_vec",
"meta_vec",
}
)
entity_model = type(
entity_model_name,
(Entity,),
{
"__tablename__": entity_table_name,
"__table_args__": (
Index("idx_entity_type", "entity_type"),
Index("idx_entity_name", "name"),
{"extend_existing": True},
),
},
table=True,
)
return entity_model