Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions src/xagent/core/tools/core/RAG_tools/core/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -1879,6 +1879,110 @@ def to_legacy_dicts(self) -> List[Dict[str, Any]]:
return [chunk.to_legacy_dict() for chunk in self.chunks]


class EmbeddingRecordDetail(BaseModel):
"""Lossless semantic record for a single ``embeddings_{model_tag}`` row (#510).

Mirrors the full embeddings schema (vector included) so a snapshot can be
restored by re-upserting every column exactly.
"""

model_config = ConfigDict(frozen=True)

collection: str = Field(..., description="Collection name for data isolation")
doc_id: str = Field(..., description="Document identifier inside the collection")
chunk_id: str = Field(..., description="Chunk identifier the vector belongs to")
parse_hash: str = Field(..., description="Parse version the embedding belongs to")
model: str = Field(..., description="Embedding model name")
vector: List[float] = Field(
default_factory=list, description="Embedding vector values"
)
vector_dimension: Optional[int] = Field(None, description="Stored vector dimension")
text: Optional[str] = Field(None, description="Original chunk text")
chunk_hash: Optional[str] = Field(None, description="Per-chunk content hash")
created_at: Optional[datetime] = Field(
None, description="Embedding creation timestamp"
)
metadata: Optional[str] = Field(None, description="Serialized chunk metadata")
user_id: Optional[int] = Field(
None, description="Owner user id for multi-tenancy (None for legacy data)"
)

@classmethod
def from_row(cls, row: Dict[str, Any]) -> "EmbeddingRecordDetail":
"""Build a record from a raw ``embeddings_{model_tag}`` table row dict."""
vector_dimension = _clean_row_value(row.get("vector_dimension"))
user_id = _clean_row_value(row.get("user_id"))
vector = _clean_row_value(row.get("vector"))
return cls(
collection=_clean_row_value(row.get("collection")),
doc_id=_clean_row_value(row.get("doc_id")),
chunk_id=_clean_row_value(row.get("chunk_id")),
parse_hash=_clean_row_value(row.get("parse_hash")),
model=_clean_row_value(row.get("model")),
vector=[float(v) for v in vector] if vector is not None else [],
vector_dimension=(
int(vector_dimension) if vector_dimension is not None else None
),
text=_clean_row_value(row.get("text")),
chunk_hash=_clean_row_value(row.get("chunk_hash")),
created_at=_clean_row_value(row.get("created_at")),
metadata=_clean_row_value(row.get("metadata")),
user_id=int(user_id) if user_id is not None else None,
)

def to_legacy_dict(self) -> Dict[str, Any]:
"""Return the legacy raw-row dict shape (all embeddings columns)."""
return {
"collection": self.collection,
"doc_id": self.doc_id,
"chunk_id": self.chunk_id,
"parse_hash": self.parse_hash,
"model": self.model,
"vector": self.vector,
"vector_dimension": self.vector_dimension,
"text": self.text,
"chunk_hash": self.chunk_hash,
"created_at": self.created_at,
"metadata": self.metadata,
"user_id": self.user_id,
}


class EmbeddingRecordSnapshot(BaseModel):
"""Ordered set of embedding rows captured for rollback restore (#510).

Embeddings live in per-model tables (``embeddings_{model_tag}``), so the
snapshot also exposes per-model-tag grouping for restore: each group routes
to its own ``upsert_embeddings`` call.
"""

model_config = ConfigDict(frozen=True)

records: List[EmbeddingRecordDetail] = Field(
default_factory=list, description="Embedding rows in their original order"
)

@classmethod
def from_rows(cls, rows: List[Dict[str, Any]]) -> "EmbeddingRecordSnapshot":
"""Build a snapshot from raw ``embeddings_{model_tag}`` table row dicts."""
return cls(records=[EmbeddingRecordDetail.from_row(row) for row in rows])

def to_legacy_dicts(self) -> List[Dict[str, Any]]:
"""Return the legacy ``list[dict]`` shape for all embedding rows."""
return [record.to_legacy_dict() for record in self.records]

def group_by_model_tag(self) -> Dict[str, List[Dict[str, Any]]]:
"""Group legacy row dicts by ``to_model_tag(model)`` for restore."""
from ..LanceDB.model_tag_utils import to_model_tag

grouped: Dict[str, List[Dict[str, Any]]] = {}
for record in self.records:
grouped.setdefault(to_model_tag(record.model), []).append(
record.to_legacy_dict()
)
return grouped


class DocumentOperationResult(BaseModel):
"""Standard response for document management operations."""

Expand Down
Loading
Loading