Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added

- `BaseVectorStoreDriver.query_vector` for querying vector stores with vectors.

## [1.1.1] - 2025-01-03

### Fixed

- Incorrect deprecation warning on `ToolkitTask`.

## [1.1.0] - 2024-12-31
Expand Down
2 changes: 1 addition & 1 deletion docs/griptape-framework/misc/events.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ Handler 2 <class 'griptape.events.finish_structure_run_event.FinishStructureRunE
You can use `Structure.run_stream()` for streaming Events from the `Structure` in the form of an iterator.

!!! tip
Set `stream=True` on your [Prompt Driver](../drivers/prompt-drivers.md) in order to receive completion chunk events.
Set `stream=True` on your [Prompt Driver](../drivers/prompt-drivers.md) in order to receive completion chunk events.

```python
--8<-- "docs/griptape-framework/misc/src/events_streaming.py"
Expand Down
9 changes: 4 additions & 5 deletions griptape/drivers/vector/astradb_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,19 +140,19 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto
for match in self.collection.find(filter=find_filter, projection={"*": 1})
]

def query(
def query_vector(
self,
query: str,
vector: list[float],
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: bool = False,
**kwargs: Any,
) -> list[BaseVectorStoreDriver.Entry]:
"""Run a similarity search on the Astra DB store, based on a query string.
"""Run a similarity search on the Astra DB store, based on a vector list.

Args:
query: the query string.
vector: the vector to be queried.
count: the maximum number of results to return. If omitted, defaults will apply.
namespace: the namespace to filter results by.
include_vectors: whether to include vector data in the results.
Expand All @@ -168,7 +168,6 @@ def query(
find_filter_ns: dict[str, Any] = {} if namespace is None else {"namespace": namespace}
find_filter = {**(query_filter or {}), **find_filter_ns}
find_projection: Optional[dict[str, int]] = {"*": 1} if include_vectors else None
vector = self.embedding_driver.embed_string(query)
ann_limit = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
matches = self.collection.find(
filter=find_filter,
Expand Down
29 changes: 23 additions & 6 deletions griptape/drivers/vector/azure_mongodb_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,22 @@
class AzureMongoDbVectorStoreDriver(MongoDbAtlasVectorStoreDriver):
"""A Vector Store Driver for CosmosDB with MongoDB vCore API."""

def query(
def query_vector(
self,
query: str,
vector: list[float],
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: bool = False,
offset: Optional[int] = None,
**kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
"""Queries the MongoDB collection for documents that match the provided query string.
"""Queries the MongoDB collection for documents that match the provided vector list.

Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index.
"""
collection = self.get_collection()

# Using the embedding driver to convert the query string into a vector
vector = self.embedding_driver.embed_string(query)

count = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
offset = offset or 0

Expand Down Expand Up @@ -63,3 +60,23 @@
)
for doc in collection.aggregate(pipeline)
]

def query(
self,
query: str,
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: bool = False,
offset: Optional[int] = None,
**kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
"""Queries the MongoDB collection for documents that match the provided query string.

Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index.
"""
# Using the embedding driver to convert the query string into a vector
vector = self.embedding_driver.embed_string(query)
return self.query_vector(

Check warning on line 80 in griptape/drivers/vector/azure_mongodb_vector_store_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/vector/azure_mongodb_vector_store_driver.py#L79-L80

Added lines #L79 - L80 were not covered by tests
vector, count=count, namespace=namespace, include_vectors=include_vectors, offset=offset, **kwargs
)
17 changes: 15 additions & 2 deletions griptape/drivers/vector/base_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,18 @@ def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Opti
@abstractmethod
def load_entries(self, *, namespace: Optional[str] = None) -> list[Entry]: ...

@abstractmethod
def query_vector(
self,
vector: list[float],
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: bool = False,
**kwargs,
) -> list[Entry]:
# TODO: Mark as abstract method for griptape 2.0
raise NotImplementedError(f"{self.__class__.__name__} does not support vector query.")

def query(
self,
query: str,
Expand All @@ -148,7 +159,9 @@ def query(
namespace: Optional[str] = None,
include_vectors: bool = False,
**kwargs,
) -> list[Entry]: ...
) -> list[Entry]:
vector = self.embedding_driver.embed_string(query)
return self.query_vector(vector, count=count, namespace=namespace, include_vectors=include_vectors, **kwargs)

def _get_default_vector_id(self, value: str) -> str:
return str(uuid.uuid5(uuid.NAMESPACE_OID, value))
11 changes: 11 additions & 0 deletions griptape/drivers/vector/dummy_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,17 @@ def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Opti
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
raise DummyError(__class__.__name__, "load_entries")

def query_vector(
self,
vector: list[float],
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: bool = False,
**kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
raise DummyError(__class__.__name__, "query_vector")

def query(
self,
query: str,
Expand Down
8 changes: 3 additions & 5 deletions griptape/drivers/vector/local_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,24 +78,22 @@ def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Opti
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
return [entry for key, entry in self.entries.items() if namespace is None or entry.namespace == namespace]

def query(
def query_vector(
self,
query: str,
vector: list[float],
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: bool = False,
**kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
query_embedding = self.embedding_driver.embed_string(query)

if namespace:
entries = {k: v for (k, v) in self.entries.items() if k.startswith(f"{namespace}-")}
else:
entries = self.entries

entries_and_relatednesses = [
(entry, self.calculate_relatedness(query_embedding, entry.vector)) for entry in list(entries.values())
(entry, self.calculate_relatedness(vector, entry.vector)) for entry in list(entries.values())
]

entries_and_relatednesses.sort(key=operator.itemgetter(1), reverse=True)
Expand Down
29 changes: 23 additions & 6 deletions griptape/drivers/vector/mongodb_atlas_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,25 +114,22 @@
for doc in cursor
]

def query(
def query_vector(
self,
query: str,
vector: list[float],
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: bool = False,
offset: Optional[int] = None,
**kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
"""Queries the MongoDB collection for documents that match the provided query string.
"""Queries the MongoDB collection for documents that match the provided vector list.

Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index.
"""
collection = self.get_collection()

# Using the embedding driver to convert the query string into a vector
vector = self.embedding_driver.embed_string(query)

count = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
offset = offset or 0

Expand Down Expand Up @@ -171,6 +168,26 @@
for doc in collection.aggregate(pipeline)
]

def query(
self,
query: str,
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: bool = False,
offset: Optional[int] = None,
**kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
"""Queries the MongoDB collection for documents that match the provided query string.

Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index.
"""
# Using the embedding driver to convert the query string into a vector
vector = self.embedding_driver.embed_string(query)
return self.query_vector(

Check warning on line 187 in griptape/drivers/vector/mongodb_atlas_vector_store_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/vector/mongodb_atlas_vector_store_driver.py#L186-L187

Added lines #L186 - L187 were not covered by tests
vector, count=count, namespace=namespace, include_vectors=include_vectors, offset=offset, **kwargs
)

def delete_vector(self, vector_id: str) -> None:
"""Deletes the vector from the collection."""
collection = self.get_collection()
Expand Down
36 changes: 32 additions & 4 deletions griptape/drivers/vector/opensearch_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,9 @@
for hit in response["hits"]["hits"]
]

def query(
def query_vector(
self,
query: str,
vector: list[float],
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
Expand All @@ -130,15 +130,14 @@
field_name: str = "vector",
**kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
"""Performs a nearest neighbor search on OpenSearch to find vectors similar to the provided query string.
"""Performs a nearest neighbor search on OpenSearch to find vectors similar to the provided vector list.

Results can be limited using the count parameter and optionally filtered by a namespace.

Returns:
A list of BaseVectorStoreDriver.Entry objects, each encapsulating the retrieved vector, its similarity score, metadata, and namespace.
"""
count = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
vector = self.embedding_driver.embed_string(query)
# Base k-NN query
query_body = {"size": count, "query": {"knn": {field_name: {"vector": vector, "k": count}}}}

Expand All @@ -165,5 +164,34 @@
for hit in response["hits"]["hits"]
]

def query(
self,
query: str,
*,
count: Optional[int] = None,
namespace: Optional[str] = None,
include_vectors: bool = False,
include_metadata: bool = True,
field_name: str = "vector",
**kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
"""Performs a nearest neighbor search on OpenSearch to find vectors similar to the provided query string.

Results can be limited using the count parameter and optionally filtered by a namespace.

Returns:
A list of BaseVectorStoreDriver.Entry objects, each encapsulating the retrieved vector, its similarity score, metadata, and namespace.
"""
vector = self.embedding_driver.embed_string(query)
return self.query_vector(

Check warning on line 186 in griptape/drivers/vector/opensearch_vector_store_driver.py

View check run for this annotation

Codecov / codecov/patch

griptape/drivers/vector/opensearch_vector_store_driver.py#L185-L186

Added lines #L185 - L186 were not covered by tests
vector,
count=count,
namespace=namespace,
include_vectors=include_vectors,
include_metadata=include_metadata,
field_name=field_name,
**kwargs,
)

def delete_vector(self, vector_id: str) -> NoReturn:
raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")
27 changes: 23 additions & 4 deletions griptape/drivers/vector/pgvector_vector_store_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,9 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto
for result in results
]

def query(
def query_vector(
self,
query: str,
vector: list[float],
*,
count: Optional[int] = BaseVectorStoreDriver.DEFAULT_QUERY_COUNT,
namespace: Optional[str] = None,
Expand All @@ -152,8 +152,6 @@ def query(
op = distance_metrics[distance_metric]

with sqlalchemy_orm.Session(self.engine) as session:
vector = self.embedding_driver.embed_string(query)

# The query should return both the vector and the distance metric score.
query_result = session.query(self._model, op(vector).label("score")).order_by(op(vector)) # pyright: ignore[reportOptionalCall]

Expand Down Expand Up @@ -182,6 +180,27 @@ def query(
for result in results
]

def query(
self,
query: str,
*,
count: Optional[int] = BaseVectorStoreDriver.DEFAULT_QUERY_COUNT,
namespace: Optional[str] = None,
include_vectors: bool = False,
distance_metric: str = "cosine_distance",
**kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
"""Performs a search on the collection to find vectors similar to the provided input vector, optionally filtering to only those that match the provided namespace."""
vector = self.embedding_driver.embed_string(query)
return self.query_vector(
vector,
count=count,
namespace=namespace,
include_vectors=include_vectors,
distance_metric=distance_metric,
**kwargs,
)

def default_vector_model(self) -> Any:
pgvector_sqlalchemy = import_optional_dependency("pgvector.sqlalchemy")
sqlalchemy = import_optional_dependency("sqlalchemy")
Expand Down
Loading
Loading