Skip to content

Commit 4516534

Browse files
authored
Implement query_vector() for all vector_store_drivers (#1494)
1 parent 36394cc commit 4516534

24 files changed

+351
-47
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## Unreleased
99

10+
### Added
11+
12+
- `BaseVectorStoreDriver.query_vector` for querying vector stores with vectors.
13+
1014
## [1.1.1] - 2025-01-03
1115

1216
### Fixed
17+
1318
- Incorrect deprecation warning on `ToolkitTask`.
1419

1520
## [1.1.0] - 2024-12-31

docs/griptape-framework/misc/events.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ Handler 2 <class 'griptape.events.finish_structure_run_event.FinishStructureRunE
7979
You can use `Structure.run_stream()` for streaming Events from the `Structure` in the form of an iterator.
8080

8181
!!! tip
82-
Set `stream=True` on your [Prompt Driver](../drivers/prompt-drivers.md) in order to receive completion chunk events.
82+
Set `stream=True` on your [Prompt Driver](../drivers/prompt-drivers.md) in order to receive completion chunk events.
8383

8484
```python
8585
--8<-- "docs/griptape-framework/misc/src/events_streaming.py"

griptape/drivers/vector/astradb_vector_store_driver.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -140,19 +140,19 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto
140140
for match in self.collection.find(filter=find_filter, projection={"*": 1})
141141
]
142142

143-
def query(
143+
def query_vector(
144144
self,
145-
query: str,
145+
vector: list[float],
146146
*,
147147
count: Optional[int] = None,
148148
namespace: Optional[str] = None,
149149
include_vectors: bool = False,
150150
**kwargs: Any,
151151
) -> list[BaseVectorStoreDriver.Entry]:
152-
"""Run a similarity search on the Astra DB store, based on a query string.
152+
"""Run a similarity search on the Astra DB store, based on a vector list.
153153
154154
Args:
155-
query: the query string.
155+
vector: the vector to be queried.
156156
count: the maximum number of results to return. If omitted, defaults will apply.
157157
namespace: the namespace to filter results by.
158158
include_vectors: whether to include vector data in the results.
@@ -168,7 +168,6 @@ def query(
168168
find_filter_ns: dict[str, Any] = {} if namespace is None else {"namespace": namespace}
169169
find_filter = {**(query_filter or {}), **find_filter_ns}
170170
find_projection: Optional[dict[str, int]] = {"*": 1} if include_vectors else None
171-
vector = self.embedding_driver.embed_string(query)
172171
ann_limit = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
173172
matches = self.collection.find(
174173
filter=find_filter,

griptape/drivers/vector/azure_mongodb_vector_store_driver.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,25 +11,22 @@
1111
class AzureMongoDbVectorStoreDriver(MongoDbAtlasVectorStoreDriver):
1212
"""A Vector Store Driver for CosmosDB with MongoDB vCore API."""
1313

14-
def query(
14+
def query_vector(
1515
self,
16-
query: str,
16+
vector: list[float],
1717
*,
1818
count: Optional[int] = None,
1919
namespace: Optional[str] = None,
2020
include_vectors: bool = False,
2121
offset: Optional[int] = None,
2222
**kwargs,
2323
) -> list[BaseVectorStoreDriver.Entry]:
24-
"""Queries the MongoDB collection for documents that match the provided query string.
24+
"""Queries the MongoDB collection for documents that match the provided vector list.
2525
2626
Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index.
2727
"""
2828
collection = self.get_collection()
2929

30-
# Using the embedding driver to convert the query string into a vector
31-
vector = self.embedding_driver.embed_string(query)
32-
3330
count = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
3431
offset = offset or 0
3532

@@ -63,3 +60,23 @@ def query(
6360
)
6461
for doc in collection.aggregate(pipeline)
6562
]
63+
64+
def query(
65+
self,
66+
query: str,
67+
*,
68+
count: Optional[int] = None,
69+
namespace: Optional[str] = None,
70+
include_vectors: bool = False,
71+
offset: Optional[int] = None,
72+
**kwargs,
73+
) -> list[BaseVectorStoreDriver.Entry]:
74+
"""Queries the MongoDB collection for documents that match the provided query string.
75+
76+
Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index.
77+
"""
78+
# Using the embedding driver to convert the query string into a vector
79+
vector = self.embedding_driver.embed_string(query)
80+
return self.query_vector(
81+
vector, count=count, namespace=namespace, include_vectors=include_vectors, offset=offset, **kwargs
82+
)

griptape/drivers/vector/base_vector_store_driver.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,18 @@ def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Opti
139139
@abstractmethod
140140
def load_entries(self, *, namespace: Optional[str] = None) -> list[Entry]: ...
141141

142-
@abstractmethod
142+
def query_vector(
143+
self,
144+
vector: list[float],
145+
*,
146+
count: Optional[int] = None,
147+
namespace: Optional[str] = None,
148+
include_vectors: bool = False,
149+
**kwargs,
150+
) -> list[Entry]:
151+
# TODO: Mark as abstract method for griptape 2.0
152+
raise NotImplementedError(f"{self.__class__.__name__} does not support vector query.")
153+
143154
def query(
144155
self,
145156
query: str,
@@ -148,7 +159,9 @@ def query(
148159
namespace: Optional[str] = None,
149160
include_vectors: bool = False,
150161
**kwargs,
151-
) -> list[Entry]: ...
162+
) -> list[Entry]:
163+
vector = self.embedding_driver.embed_string(query)
164+
return self.query_vector(vector, count=count, namespace=namespace, include_vectors=include_vectors, **kwargs)
152165

153166
def _get_default_vector_id(self, value: str) -> str:
154167
return str(uuid.uuid5(uuid.NAMESPACE_OID, value))

griptape/drivers/vector/dummy_vector_store_driver.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,17 @@ def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Opti
3535
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
3636
raise DummyError(__class__.__name__, "load_entries")
3737

38+
def query_vector(
39+
self,
40+
vector: list[float],
41+
*,
42+
count: Optional[int] = None,
43+
namespace: Optional[str] = None,
44+
include_vectors: bool = False,
45+
**kwargs,
46+
) -> list[BaseVectorStoreDriver.Entry]:
47+
raise DummyError(__class__.__name__, "query_vector")
48+
3849
def query(
3950
self,
4051
query: str,

griptape/drivers/vector/local_vector_store_driver.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,24 +78,22 @@ def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Opti
7878
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
7979
return [entry for key, entry in self.entries.items() if namespace is None or entry.namespace == namespace]
8080

81-
def query(
81+
def query_vector(
8282
self,
83-
query: str,
83+
vector: list[float],
8484
*,
8585
count: Optional[int] = None,
8686
namespace: Optional[str] = None,
8787
include_vectors: bool = False,
8888
**kwargs,
8989
) -> list[BaseVectorStoreDriver.Entry]:
90-
query_embedding = self.embedding_driver.embed_string(query)
91-
9290
if namespace:
9391
entries = {k: v for (k, v) in self.entries.items() if k.startswith(f"{namespace}-")}
9492
else:
9593
entries = self.entries
9694

9795
entries_and_relatednesses = [
98-
(entry, self.calculate_relatedness(query_embedding, entry.vector)) for entry in list(entries.values())
96+
(entry, self.calculate_relatedness(vector, entry.vector)) for entry in list(entries.values())
9997
]
10098

10199
entries_and_relatednesses.sort(key=operator.itemgetter(1), reverse=True)

griptape/drivers/vector/mongodb_atlas_vector_store_driver.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,25 +114,22 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto
114114
for doc in cursor
115115
]
116116

117-
def query(
117+
def query_vector(
118118
self,
119-
query: str,
119+
vector: list[float],
120120
*,
121121
count: Optional[int] = None,
122122
namespace: Optional[str] = None,
123123
include_vectors: bool = False,
124124
offset: Optional[int] = None,
125125
**kwargs,
126126
) -> list[BaseVectorStoreDriver.Entry]:
127-
"""Queries the MongoDB collection for documents that match the provided query string.
127+
"""Queries the MongoDB collection for documents that match the provided vector list.
128128
129129
Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index.
130130
"""
131131
collection = self.get_collection()
132132

133-
# Using the embedding driver to convert the query string into a vector
134-
vector = self.embedding_driver.embed_string(query)
135-
136133
count = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
137134
offset = offset or 0
138135

@@ -171,6 +168,26 @@ def query(
171168
for doc in collection.aggregate(pipeline)
172169
]
173170

171+
def query(
172+
self,
173+
query: str,
174+
*,
175+
count: Optional[int] = None,
176+
namespace: Optional[str] = None,
177+
include_vectors: bool = False,
178+
offset: Optional[int] = None,
179+
**kwargs,
180+
) -> list[BaseVectorStoreDriver.Entry]:
181+
"""Queries the MongoDB collection for documents that match the provided query string.
182+
183+
Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index.
184+
"""
185+
# Using the embedding driver to convert the query string into a vector
186+
vector = self.embedding_driver.embed_string(query)
187+
return self.query_vector(
188+
vector, count=count, namespace=namespace, include_vectors=include_vectors, offset=offset, **kwargs
189+
)
190+
174191
def delete_vector(self, vector_id: str) -> None:
175192
"""Deletes the vector from the collection."""
176193
collection = self.get_collection()

griptape/drivers/vector/opensearch_vector_store_driver.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,9 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto
119119
for hit in response["hits"]["hits"]
120120
]
121121

122-
def query(
122+
def query_vector(
123123
self,
124-
query: str,
124+
vector: list[float],
125125
*,
126126
count: Optional[int] = None,
127127
namespace: Optional[str] = None,
@@ -130,15 +130,14 @@ def query(
130130
field_name: str = "vector",
131131
**kwargs,
132132
) -> list[BaseVectorStoreDriver.Entry]:
133-
"""Performs a nearest neighbor search on OpenSearch to find vectors similar to the provided query string.
133+
"""Performs a nearest neighbor search on OpenSearch to find vectors similar to the provided vector list.
134134
135135
Results can be limited using the count parameter and optionally filtered by a namespace.
136136
137137
Returns:
138138
A list of BaseVectorStoreDriver.Entry objects, each encapsulating the retrieved vector, its similarity score, metadata, and namespace.
139139
"""
140140
count = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
141-
vector = self.embedding_driver.embed_string(query)
142141
# Base k-NN query
143142
query_body = {"size": count, "query": {"knn": {field_name: {"vector": vector, "k": count}}}}
144143

@@ -165,5 +164,34 @@ def query(
165164
for hit in response["hits"]["hits"]
166165
]
167166

167+
def query(
168+
self,
169+
query: str,
170+
*,
171+
count: Optional[int] = None,
172+
namespace: Optional[str] = None,
173+
include_vectors: bool = False,
174+
include_metadata: bool = True,
175+
field_name: str = "vector",
176+
**kwargs,
177+
) -> list[BaseVectorStoreDriver.Entry]:
178+
"""Performs a nearest neighbor search on OpenSearch to find vectors similar to the provided query string.
179+
180+
Results can be limited using the count parameter and optionally filtered by a namespace.
181+
182+
Returns:
183+
A list of BaseVectorStoreDriver.Entry objects, each encapsulating the retrieved vector, its similarity score, metadata, and namespace.
184+
"""
185+
vector = self.embedding_driver.embed_string(query)
186+
return self.query_vector(
187+
vector,
188+
count=count,
189+
namespace=namespace,
190+
include_vectors=include_vectors,
191+
include_metadata=include_metadata,
192+
field_name=field_name,
193+
**kwargs,
194+
)
195+
168196
def delete_vector(self, vector_id: str) -> NoReturn:
169197
raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")

griptape/drivers/vector/pgvector_vector_store_driver.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,9 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto
127127
for result in results
128128
]
129129

130-
def query(
130+
def query_vector(
131131
self,
132-
query: str,
132+
vector: list[float],
133133
*,
134134
count: Optional[int] = BaseVectorStoreDriver.DEFAULT_QUERY_COUNT,
135135
namespace: Optional[str] = None,
@@ -152,8 +152,6 @@ def query(
152152
op = distance_metrics[distance_metric]
153153

154154
with sqlalchemy_orm.Session(self.engine) as session:
155-
vector = self.embedding_driver.embed_string(query)
156-
157155
# The query should return both the vector and the distance metric score.
158156
query_result = session.query(self._model, op(vector).label("score")).order_by(op(vector)) # pyright: ignore[reportOptionalCall]
159157

@@ -182,6 +180,27 @@ def query(
182180
for result in results
183181
]
184182

183+
def query(
184+
self,
185+
query: str,
186+
*,
187+
count: Optional[int] = BaseVectorStoreDriver.DEFAULT_QUERY_COUNT,
188+
namespace: Optional[str] = None,
189+
include_vectors: bool = False,
190+
distance_metric: str = "cosine_distance",
191+
**kwargs,
192+
) -> list[BaseVectorStoreDriver.Entry]:
193+
"""Performs a search on the collection to find vectors similar to the provided input vector, optionally filtering to only those that match the provided namespace."""
194+
vector = self.embedding_driver.embed_string(query)
195+
return self.query_vector(
196+
vector,
197+
count=count,
198+
namespace=namespace,
199+
include_vectors=include_vectors,
200+
distance_metric=distance_metric,
201+
**kwargs,
202+
)
203+
185204
def default_vector_model(self) -> Any:
186205
pgvector_sqlalchemy = import_optional_dependency("pgvector.sqlalchemy")
187206
sqlalchemy = import_optional_dependency("sqlalchemy")

0 commit comments

Comments
 (0)