11import typing
22
3- import pydantic
3+ from pydantic import BaseModel
4+ from pydantic import Field
5+ from pydantic import model_validator
46
57from shared_models import DataField
68
@@ -22,110 +24,32 @@ def to_filter_values(cls, data_field: DataField) -> list[str]:
2224 raise ValueError (f"No type mapping defined for { data_field } " ) from err
2325
2426
class VectorSearchParams(BaseModel):
    """Parameters describing a single vector search.

    ``filter_value`` is derived rather than supplied: the
    ``compute_filter_value`` validator fills it in from ``data_field`` once
    the other fields have been validated.
    """

    vector: list[float] = Field(description="The vector to search for.")
    vector_field: str = Field(
        default="descriptionVector",
        description="The field to perform the vector search on.",
    )
    filter_field: str = Field(
        default="type",
        description="The field to filter on, e.g., 'type'.",
    )
    data_field: DataField = Field(
        description="The value of the field to filter on, e.g., 'Lab Test Name Ordered' or 'Lab Test Name Resulted'.",
    )
    size: int = Field(default=10, description="The number of results to retrieve.")
    k: int = Field(
        default=10,
        description="The number of nearest neighbors to examine during the query.",
    )
    # NOTE(review): pydantic documents ``init`` as applying to dataclasses only;
    # confirm it has the intended effect on a BaseModel field.
    filter_value: list[str] = Field(
        default_factory=list,
        init=False,
        description="The list of filter values corresponding to the data_field, computed after initialization.",
    )

    @model_validator(mode="after")
    def compute_filter_value(self) -> "VectorSearchParams":
        """Populate ``filter_value`` from ``data_field`` via DataFieldTypeMapping.

        Only the declared default of ``filter_field`` is supported; any other
        value is rejected.

        Raises:
            ValueError: If ``filter_field`` differs from its declared default.
        """
        supported_field = type(self).model_fields["filter_field"].default
        if self.filter_field != supported_field:
            raise ValueError(f"Unsupported filter field: {self.filter_field}")

        self.filter_value = DataFieldTypeMapping.to_filter_values(self.data_field)
        return self
56-
57-
class S3Location(pydantic.BaseModel):
    """Bucket/key pair locating the S3 file that contained the relevant data."""

    bucket: str = pydantic.Field(
        description="The S3 bucket where the file is located.",
    )
    key: str = pydantic.Field(
        description="The S3 key (path) where the file is located.",
    )
63-
64-
class OpenSearchHitSource(pydantic.BaseModel):
    """The ``_source`` payload of one search result returned from OpenSearch.

    Carries the embedding-data ID, the LOINC code, name type and type, a
    description, and the S3 location associated with the hit.
    """

    id: int = pydantic.Field(description="The unique ID from the embedding data of the search result hit.")
    loinc_code: str = pydantic.Field(
        description="The LOINC code of the search result hit.",
    )
    loinc_name_type: str = pydantic.Field(description="The LOINC name type of the search result hit.")
    description: str = pydantic.Field(
        description="The description of the search result hit.",
    )
    loinc_type: str = pydantic.Field(
        description="The LOINC type of the search result hit.",
    )
    s3: S3Location = pydantic.Field(
        description="The S3 location of the search result hit.",
    )
78-
79-
class OpenSearchHit(pydantic.BaseModel):
    """One search result hit returned from OpenSearch.

    Each attribute is populated from the underscore-prefixed key given as its
    alias (``_index``, ``_id``, ``_score``, ``_source``).
    """

    index: str = pydantic.Field(
        alias="_index",
        description="The index that the search result hit came from.",
    )
    id: str = pydantic.Field(
        alias="_id",
        description="The unique OpenSearch ID of the search result hit.",
    )
    score: float = pydantic.Field(
        alias="_score",
        description="The cosine similarity score of the search result hit.",
    )
    source: OpenSearchHitSource = pydantic.Field(
        alias="_source",
        description="The source of the search result hit.",
    )
95-
96-
class OpenSearchHits(pydantic.BaseModel):
    """Represents all of the search result hits returned from OpenSearch.

    ``total_hits`` is read from the ``total`` key of the response.
    """

    # OpenSearch reports ``total`` as an object such as
    # {"value": 42, "relation": "eq"}: "value" is an integer but "relation" is
    # a string, so the mapping's value type must admit both. A plain
    # ``dict[str, int]`` fails validation on the "relation" entry.
    total_hits: dict[str, typing.Union[int, str]] = pydantic.Field(
        alias="total", description="The total number of hits returned from OpenSearch."
    )
    hits: list[OpenSearchHit] = pydantic.Field(
        description="The list of search result hits returned from OpenSearch."
    )
106-
107-
class OpenSearchShards(pydantic.BaseModel):
    """Shard counters reported by OpenSearch for a search.

    Holds the total number of shards involved plus how many succeeded, were
    skipped, or failed.
    """

    total: int = pydantic.Field(
        description="The total number of shards involved in the search.",
    )
    successful: int = pydantic.Field(description="The number of shards that successfully returned results.")
    skipped: int = pydantic.Field(description="The number of shards that were skipped during the search.")
    failed: int = pydantic.Field(
        description="The number of shards that failed to return results.",
    )
119-
120-
class OpenSearchResult(pydantic.BaseModel):
    """Top-level search result returned from OpenSearch.

    Combines the timing and timeout flags with the shard information (read
    from the ``_shards`` key) and the search result hits.
    """

    took: int = pydantic.Field(
        description="The time taken to execute the search in milliseconds.",
    )
    timed_out: bool = pydantic.Field(
        description="Indicates whether the search timed out.",
    )
    shards: OpenSearchShards = pydantic.Field(
        alias="_shards",
        description="The shard information for the search.",
    )
    hits: OpenSearchHits = pydantic.Field(description="The search result hits returned from OpenSearch.")
0 commit comments