Skip to content

Commit aa265e0

Browse files
Josh/clean up opensearch classes (#381)
## Description

Remove some duplicated models being used in lambda_handler.

## Related Issues

Closes n/a
1 parent d1797a9 commit aa265e0

File tree

2 files changed

+42
-129
lines changed
  • packages

2 files changed

+42
-129
lines changed
Lines changed: 30 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,77 +1,66 @@
1-
import pydantic
1+
from pydantic import BaseModel
2+
from pydantic import Field
23

34

4-
class S3Location(pydantic.BaseModel):
5+
class S3Location(BaseModel):
56
"""Represents the location of a file in S3, indicating which file contained the relevant data."""
67

7-
bucket: str = pydantic.Field(description="The S3 bucket where the file is located.")
8-
key: str = pydantic.Field(description="The S3 key (path) where the file is located.")
8+
bucket: str = Field(description="The S3 bucket where the file is located.")
9+
key: str = Field(description="The S3 key (path) where the file is located.")
910

1011

11-
class OpenSearchHitSource(pydantic.BaseModel):
12+
class OpenSearchHitSource(BaseModel):
1213
"""Represents a single search result _source returned from OpenSearch."""
1314

14-
id: int = pydantic.Field(
15-
description="The unique ID from the embedding data of the search result hit."
16-
)
17-
loinc_code: str = pydantic.Field(description="The LOINC code of the search result hit.")
18-
loinc_name_type: str = pydantic.Field(
19-
description="The LOINC name type of the search result hit."
20-
)
21-
description: str = pydantic.Field(description="The description of the search result hit.")
22-
loinc_type: str = pydantic.Field(description="The LOINC type of the search result hit.")
23-
s3: S3Location = pydantic.Field(description="The S3 location of the search result hit.")
15+
id: int = Field(description="The unique ID from the embedding data of the search result hit.")
16+
loinc_code: str = Field(description="The LOINC code of the search result hit.")
17+
loinc_name_type: str = Field(description="The LOINC name type of the search result hit.")
18+
description: str = Field(description="The description of the search result hit.")
19+
loinc_type: str = Field(description="The LOINC type of the search result hit.")
20+
s3: S3Location = Field(description="The S3 location of the search result hit.")
2421

2522

26-
class OpenSearchHit(pydantic.BaseModel):
23+
class OpenSearchHit(BaseModel):
2724
"""Represents a single search result hit returned from OpenSearch."""
2825

29-
index: str = pydantic.Field(
26+
index: str = Field(
3027
description="The index that the search result hit came from.", alias="_index"
3128
)
32-
id: str = pydantic.Field(
33-
description="The unique OpenSearch ID of the search result hit.", alias="_id"
34-
)
35-
score: float = pydantic.Field(
29+
id: str = Field(description="The unique OpenSearch ID of the search result hit.", alias="_id")
30+
score: float = Field(
3631
description="The cosine similarity score of the search result hit.", alias="_score"
3732
)
38-
source: OpenSearchHitSource = pydantic.Field(
33+
source: OpenSearchHitSource = Field(
3934
description="The source of the search result hit.", alias="_source"
4035
)
4136

4237

43-
class OpenSearchHits(pydantic.BaseModel):
38+
class OpenSearchHits(BaseModel):
4439
"""Represents all of the search result hits returned from OpenSearch."""
4540

46-
total_hits: dict[str, int] = pydantic.Field(
41+
total_hits: dict[str, int] = Field(
4742
alias="total", description="The total number of hits returned from OpenSearch."
4843
)
49-
hits: list[OpenSearchHit] = pydantic.Field(
44+
hits: list[OpenSearchHit] = Field(
5045
description="The list of search result hits returned from OpenSearch."
5146
)
5247

5348

54-
class OpenSearchShards(pydantic.BaseModel):
49+
class OpenSearchShards(BaseModel):
5550
"""Represents the shard information returned from OpenSearch."""
5651

57-
total: int = pydantic.Field(description="The total number of shards involved in the search.")
58-
successful: int = pydantic.Field(
59-
description="The number of shards that successfully returned results."
60-
)
61-
skipped: int = pydantic.Field(
62-
description="The number of shards that were skipped during the search."
63-
)
64-
failed: int = pydantic.Field(description="The number of shards that failed to return results.")
52+
total: int = Field(description="The total number of shards involved in the search.")
53+
successful: int = Field(description="The number of shards that successfully returned results.")
54+
skipped: int = Field(description="The number of shards that were skipped during the search.")
55+
failed: int = Field(description="The number of shards that failed to return results.")
6556

6657

67-
class OpenSearchResult(pydantic.BaseModel):
58+
class OpenSearchResult(BaseModel):
6859
"""Represents the overall search result returned from OpenSearch, including hits and shard information."""
6960

70-
took: int = pydantic.Field(description="The time taken to execute the search in milliseconds.")
71-
timed_out: bool = pydantic.Field(description="Indicates whether the search timed out.")
72-
shards: OpenSearchShards = pydantic.Field(
61+
took: int = Field(description="The time taken to execute the search in milliseconds.")
62+
timed_out: bool = Field(description="Indicates whether the search timed out.")
63+
shards: OpenSearchShards = Field(
7364
description="The shard information for the search.", alias="_shards"
7465
)
75-
hits: OpenSearchHits = pydantic.Field(
76-
description="The search result hits returned from OpenSearch."
77-
)
66+
hits: OpenSearchHits = Field(description="The search result hits returned from OpenSearch.")
Lines changed: 12 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import typing
22

3-
import pydantic
3+
from pydantic import BaseModel
4+
from pydantic import Field
5+
from pydantic import model_validator
46

57
from shared_models import DataField
68

@@ -22,110 +24,32 @@ def to_filter_values(cls, data_field: DataField) -> list[str]:
2224
raise ValueError(f"No type mapping defined for {data_field}") from err
2325

2426

25-
class VectorSearchParams(pydantic.BaseModel):
27+
class VectorSearchParams(BaseModel):
2628
"""Parameters for performing a vector search."""
2729

28-
vector: list[float] = pydantic.Field(description="The vector to search for.")
29-
vector_field: str = pydantic.Field(
30+
vector: list[float] = Field(description="The vector to search for.")
31+
vector_field: str = Field(
3032
default="descriptionVector", description="The field to perform the vector search on."
3133
)
32-
filter_field: str = pydantic.Field(
33-
default="type", description="The field to filter on, e.g., 'type'."
34-
)
35-
data_field: DataField = pydantic.Field(
34+
filter_field: str = Field(default="type", description="The field to filter on, e.g., 'type'.")
35+
data_field: DataField = Field(
3636
description="The value of the field to filter on, e.g., 'Lab Test Name Ordered' or 'Lab Test Name Resulted'."
3737
)
38-
size: int = pydantic.Field(default=10, description="The number of results to retrieve.")
39-
k: int = pydantic.Field(
38+
size: int = Field(default=10, description="The number of results to retrieve.")
39+
k: int = Field(
4040
default=10, description="The number of nearest neighbors to examine during the query."
4141
)
42-
filter_value: list[str] = pydantic.Field(
42+
filter_value: list[str] = Field(
4343
default_factory=list,
4444
init=False,
4545
description="The list of filter values corresponding to the data_field, computed after initialization.",
4646
)
4747

48-
@pydantic.model_validator(mode="after")
48+
@model_validator(mode="after")
4949
def compute_filter_value(self) -> "VectorSearchParams":
5050
"""Uses the DataFieldTypeMapping to get the filter values corresponding to the data_field."""
5151
if self.filter_field == type(self).model_fields["filter_field"].default:
5252
self.filter_value = DataFieldTypeMapping.to_filter_values(self.data_field)
5353
else:
5454
raise ValueError(f"Unsupported filter field: {self.filter_field}")
5555
return self
56-
57-
58-
class S3Location(pydantic.BaseModel):
59-
"""Represents the location of a file in S3, indicating which file contained the relevant data."""
60-
61-
bucket: str = pydantic.Field(description="The S3 bucket where the file is located.")
62-
key: str = pydantic.Field(description="The S3 key (path) where the file is located.")
63-
64-
65-
class OpenSearchHitSource(pydantic.BaseModel):
66-
"""Represents a single search result _source returned from OpenSearch."""
67-
68-
id: int = pydantic.Field(
69-
description="The unique ID from the embedding data of the search result hit."
70-
)
71-
loinc_code: str = pydantic.Field(description="The LOINC code of the search result hit.")
72-
loinc_name_type: str = pydantic.Field(
73-
description="The LOINC name type of the search result hit."
74-
)
75-
description: str = pydantic.Field(description="The description of the search result hit.")
76-
loinc_type: str = pydantic.Field(description="The LOINC type of the search result hit.")
77-
s3: S3Location = pydantic.Field(description="The S3 location of the search result hit.")
78-
79-
80-
class OpenSearchHit(pydantic.BaseModel):
81-
"""Represents a single search result hit returned from OpenSearch."""
82-
83-
index: str = pydantic.Field(
84-
description="The index that the search result hit came from.", alias="_index"
85-
)
86-
id: str = pydantic.Field(
87-
description="The unique OpenSearch ID of the search result hit.", alias="_id"
88-
)
89-
score: float = pydantic.Field(
90-
description="The cosine similarity score of the search result hit.", alias="_score"
91-
)
92-
source: OpenSearchHitSource = pydantic.Field(
93-
description="The source of the search result hit.", alias="_source"
94-
)
95-
96-
97-
class OpenSearchHits(pydantic.BaseModel):
98-
"""Represents all of the search result hits returned from OpenSearch."""
99-
100-
total_hits: dict[str, int] = pydantic.Field(
101-
alias="total", description="The total number of hits returned from OpenSearch."
102-
)
103-
hits: list[OpenSearchHit] = pydantic.Field(
104-
description="The list of search result hits returned from OpenSearch."
105-
)
106-
107-
108-
class OpenSearchShards(pydantic.BaseModel):
109-
"""Represents the shard information returned from OpenSearch."""
110-
111-
total: int = pydantic.Field(description="The total number of shards involved in the search.")
112-
successful: int = pydantic.Field(
113-
description="The number of shards that successfully returned results."
114-
)
115-
skipped: int = pydantic.Field(
116-
description="The number of shards that were skipped during the search."
117-
)
118-
failed: int = pydantic.Field(description="The number of shards that failed to return results.")
119-
120-
121-
class OpenSearchResult(pydantic.BaseModel):
122-
"""Represents the overall search result returned from OpenSearch, including hits and shard information."""
123-
124-
took: int = pydantic.Field(description="The time taken to execute the search in milliseconds.")
125-
timed_out: bool = pydantic.Field(description="Indicates whether the search timed out.")
126-
shards: OpenSearchShards = pydantic.Field(
127-
description="The shard information for the search.", alias="_shards"
128-
)
129-
hits: OpenSearchHits = pydantic.Field(
130-
description="The search result hits returned from OpenSearch."
131-
)

0 commit comments

Comments
 (0)