-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathschemas.py
More file actions
167 lines (128 loc) · 4.76 KB
/
schemas.py
File metadata and controls
167 lines (128 loc) · 4.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
from __future__ import annotations

from datetime import date, datetime, timezone
from enum import Enum
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, Field, field_validator, model_validator
class DocType(str, Enum):
    """Category of a knowledge-base document, used for retrieval and ranking.

    Members subclass ``str``, so each one compares equal to its raw value
    (for example ``DocType.FAQ == "faq"``) and can be built from it with
    ``DocType("faq")``.
    """

    # NOTE: member order is the Enum iteration order; keep it stable.
    GENERAL = "general"
    PLAN = "plan"
    EXCEPTION = "exception"
    REGION = "region"
    AUDIENCE = "audience"
    PROCESS = "process"
    FAQ = "faq"
class InsuranceDocument(BaseModel):
    """Raw document as stored in the dataset.

    Validation normalizes the record:

    * ``doc_id`` and ``title`` are stripped of surrounding whitespace and
      must not be blank after stripping.
    * ``plan`` and ``region`` are stripped and lower-cased; a blank value
      becomes ``None``.
    * ``audience`` and ``topic`` entries are stripped and lower-cased; blank
      entries are dropped and duplicates removed (first occurrence wins).
    * ``content`` has every whitespace run collapsed to a single space and
      must not be empty afterwards.
    * When ``valid_to`` is set it must not be earlier than ``valid_from``.

    Validation failures surface as pydantic ``ValidationError``.
    """

    doc_id: str = Field(..., min_length=1)
    title: str = Field(..., min_length=1)
    doc_type: DocType
    plan: Optional[str] = None
    region: Optional[str] = None
    audience: List[str] = Field(default_factory=list)
    topic: List[str] = Field(default_factory=list)
    priority: int = Field(..., ge=1, le=5)  # inclusive range 1..5
    valid_from: date
    valid_to: Optional[date] = None  # optional; checked against valid_from below
    content: str = Field(..., min_length=1)

    @field_validator("doc_id", "title")
    @classmethod
    def strip_required_strings(cls, value: str) -> str:
        """Strip surrounding whitespace and reject values that become empty.

        ``min_length=1`` alone would accept a whitespace-only string such as
        ``"   "``; this closes that gap.
        """
        cleaned = value.strip()
        if not cleaned:
            raise ValueError("value must not be empty")
        return cleaned

    @field_validator("plan", "region")
    @classmethod
    def normalize_optional_strings(cls, value: Optional[str]) -> Optional[str]:
        """Lower-case and strip *value*; map ``None`` or blank to ``None``."""
        if value is None:
            return None
        normalized = value.strip().lower()
        return normalized or None

    @field_validator("audience", "topic")
    @classmethod
    def normalize_string_lists(cls, values: List[str]) -> List[str]:
        """Strip/lower-case entries, dropping blanks and later duplicates."""
        normalized: List[str] = []
        seen = set()
        for value in values:
            item = value.strip().lower()
            if item and item not in seen:
                normalized.append(item)
                seen.add(item)
        return normalized

    @field_validator("content")
    @classmethod
    def normalize_content(cls, value: str) -> str:
        """Collapse all whitespace runs to single spaces; reject empty text."""
        # min_length=1 runs before this validator, so whitespace-only content
        # is caught here after collapsing.
        cleaned = " ".join(value.split())
        if not cleaned:
            raise ValueError("content must not be empty")
        return cleaned

    @model_validator(mode="after")
    def validate_date_range(self) -> "InsuranceDocument":
        """Reject a ``valid_to`` that falls before ``valid_from``."""
        if self.valid_to is not None and self.valid_to < self.valid_from:
            raise ValueError("valid_to cannot be earlier than valid_from")
        return self
class DocumentChunk(BaseModel):
    """Chunked retrieval unit derived from an InsuranceDocument.

    The filter fields shared with ``InsuranceDocument`` are normalized the
    same way here. A chunk constructed directly therefore ends up in the
    same form as one copied from an already-validated document. The
    normalization is idempotent, so copied values pass through unchanged.

    * ``plan`` and ``region`` are stripped and lower-cased; a blank value
      becomes ``None``.
    * ``audience`` and ``topic`` entries are stripped and lower-cased; blank
      entries are dropped and duplicates removed (first occurrence wins).
    * ``text`` has every whitespace run collapsed to a single space and must
      not be empty afterwards.
    * When ``valid_to`` is set it must not be earlier than ``valid_from``.
    """

    chunk_id: str
    doc_id: str
    title: str
    doc_type: DocType
    text: str
    chunk_index: int = Field(..., ge=0)  # non-negative position
    plan: Optional[str] = None
    region: Optional[str] = None
    audience: List[str] = Field(default_factory=list)
    topic: List[str] = Field(default_factory=list)
    priority: int = Field(..., ge=1, le=5)  # inclusive range 1..5
    valid_from: date
    valid_to: Optional[date] = None
    metadata: Dict[str, Any] = Field(default_factory=dict)  # free-form extras

    @field_validator("plan", "region")
    @classmethod
    def normalize_optional_strings(cls, value: Optional[str]) -> Optional[str]:
        """Lower-case and strip *value*; map ``None`` or blank to ``None``."""
        if value is None:
            return None
        normalized = value.strip().lower()
        return normalized or None

    @field_validator("audience", "topic")
    @classmethod
    def normalize_string_lists(cls, values: List[str]) -> List[str]:
        """Strip/lower-case entries, dropping blanks and later duplicates."""
        normalized: List[str] = []
        seen = set()
        for value in values:
            item = value.strip().lower()
            if item and item not in seen:
                normalized.append(item)
                seen.add(item)
        return normalized

    @field_validator("text")
    @classmethod
    def normalize_text(cls, value: str) -> str:
        """Collapse all whitespace runs to single spaces; reject empty text."""
        cleaned = " ".join(value.split())
        if not cleaned:
            raise ValueError("chunk text must not be empty")
        return cleaned

    @model_validator(mode="after")
    def validate_date_range(self) -> "DocumentChunk":
        """Reject a ``valid_to`` that falls before ``valid_from``."""
        if self.valid_to is not None and self.valid_to < self.valid_from:
            raise ValueError("valid_to cannot be earlier than valid_from")
        return self
class AskRequest(BaseModel):
    """Incoming question payload.

    ``user_id`` and ``question`` are both stripped of surrounding
    whitespace, and a value that is empty after stripping is rejected.
    """

    user_id: str = Field(..., min_length=1)
    question: str = Field(..., min_length=1)

    @field_validator("user_id", "question")
    @classmethod
    def strip_required_strings(cls, value: str) -> str:
        """Return *value* without surrounding whitespace; reject blanks."""
        stripped = value.strip()
        if stripped:
            return stripped
        raise ValueError("value must not be empty")
class AskResponse(BaseModel):
    """Response payload: the answer text and a list of source strings."""

    answer: str
    # NOTE(review): presumably doc/chunk identifiers; confirm against the producer.
    sources: List[str]
class RetrievalIntent(str, Enum):
    """Intent label assigned to a question, stored on ``RetrievalPlan.intent``.

    Members subclass ``str`` and compare equal to their raw values
    (for example ``RetrievalIntent.FOLLOW_UP == "follow_up"``).
    """

    # NOTE: member order is the Enum iteration order; keep it stable.
    GENERAL_INFO = "general_info"
    ELIGIBILITY = "eligibility"
    EXCEPTION_CHECK = "exception_check"
    COVERAGE_CHECK = "coverage_check"
    NEXT_STEPS = "next_steps"
    FOLLOW_UP = "follow_up"
    COMPARISON = "comparison"
class RetrievalPlan(BaseModel):
    """Structured output of the lightweight agentic planner.

    Only ``intent`` is required; every other field has an empty or false
    default.
    """

    intent: RetrievalIntent
    # Free-form mapping; its key schema is set by the planner, not enforced here.
    entities: Dict[str, Any] = Field(default_factory=dict)
    retrieve_doc_types: List[DocType] = Field(default_factory=list)
    search_topics: List[str] = Field(default_factory=list)
    needs_second_pass: bool = Field(default=False)
    planner_notes: Optional[str] = Field(default=None)
class RetrievedChunk(BaseModel):
    """A ``DocumentChunk`` paired with its search and reranking signals.

    All three scores default to ``0.0``, and ``match_reasons`` defaults to an
    empty list.
    """

    chunk: DocumentChunk
    vector_score: float = Field(default=0.0)
    bm25_score: float = Field(default=0.0)
    rerank_score: float = Field(default=0.0)
    # NOTE(review): presumably short human-readable explanations; confirm with producer.
    match_reasons: List[str] = Field(default_factory=list)
def _utc_now() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Drop-in replacement for ``datetime.utcnow()``, which is deprecated since
    Python 3.12. The tzinfo is removed so timestamps keep the same naive-UTC
    form as before and still compare with existing naive values.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


class ConversationTurn(BaseModel):
    """One message in a user's conversation history."""

    # Free-form string; no set of allowed roles is enforced here.
    role: str
    content: str
    # Naive UTC time of creation unless supplied explicitly.
    timestamp: datetime = Field(default_factory=_utc_now)
class UserMemory(BaseModel):
    """Per-user memory: a free-form ``facts`` mapping plus message history."""

    user_id: str
    # Key schema is defined by whoever writes facts; not validated here.
    facts: Dict[str, Any] = Field(default_factory=dict)
    history: List[ConversationTurn] = Field(default_factory=list)

    def add_turn(self, role: str, content: str) -> None:
        """Append a new ``ConversationTurn`` with *role* and *content* to history.

        The turn's timestamp comes from ``ConversationTurn``'s default.
        """
        self.history.append(ConversationTurn(role=role, content=content))

    def trimmed_history(self, max_turns: int = 8) -> List[ConversationTurn]:
        """Return up to the last *max_turns* turns, oldest first, as a new list.

        A non-positive ``max_turns`` returns an empty list.
        """
        # Guard: history[-0:] is the WHOLE list, and a negative max_turns would
        # slice from the front (history[n:]) instead of returning nothing.
        if max_turns <= 0:
            return []
        return self.history[-max_turns:]