Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python: Add vector search to Postgres connector #10213

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
8fdb654
Add vectorized search to postgres connector
lossyrob Jan 14, 2025
db0e305
Allow kwargs and connection class in settings
lossyrob Jan 14, 2025
7762f28
Don't validate connstr in case custom connector
lossyrob Jan 14, 2025
95ea77e
Fix issue with closing connection pools
lossyrob Jan 15, 2025
54ecdf5
Allow conversion of additional selected columns
lossyrob Jan 16, 2025
0f3cb9a
Return total search count if requested
lossyrob Jan 16, 2025
be47d6f
Add integration test for postgres vector search
lossyrob Jan 16, 2025
96b14e5
Add vector search to postgres example notebook
lossyrob Jan 16, 2025
4a65db5
Rerun notebook to remove typo in gpt-4o response
lossyrob Jan 16, 2025
94a0800
Use shortcut import paths in sample
lossyrob Jan 22, 2025
4d5f4b9
Lint: Fix imports
lossyrob Jan 22, 2025
e2ac0a5
Ensure distance column name does not conflict with model field
lossyrob Jan 22, 2025
8884395
Use named placeholders in sql.SQL.format
lossyrob Jan 22, 2025
5391b8c
Remove unused fields copy; avoid including any vectors on option
lossyrob Jan 22, 2025
dd74c4f
Add comment on cosine similarity and dot prod calculation
lossyrob Jan 22, 2025
89b2a9f
Fix model_post_init
lossyrob Jan 22, 2025
79c0e32
Test include_vectors and include_total_count
lossyrob Jan 22, 2025
2b58e0c
async pull rows in search if total count not needed
lossyrob Jan 22, 2025
460839a
Fix unit test for async result fetching
lossyrob Jan 22, 2025
39a855f
Shorten import package paths
lossyrob Jan 23, 2025
24d6346
Modify distance column name conflict resolve logic
lossyrob Jan 23, 2025
3eca521
Use data_model_definition.get_field_names
lossyrob Jan 23, 2025
178588d
Inline select_fields
lossyrob Jan 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
427 changes: 400 additions & 27 deletions python/samples/getting_started/third_party/postgres-memory.ipynb

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
# Limitation based on pgvector documentation https://github.com/pgvector/pgvector#what-if-i-want-to-index-vectors-with-more-than-2000-dimensions
MAX_DIMENSIONALITY = 2000

# The name of the column that returns the distance value in the database.
# It is used in the similarity search query. Must not conflict with a model property name.
DISTANCE_COLUMN_NAME = "sk_pg_distance"
eavanvalkenburg marked this conversation as resolved.
Show resolved Hide resolved

# Environment Variables
PGHOST_ENV_VAR = "PGHOST"
PGPORT_ENV_VAR = "PGPORT"
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from psycopg.conninfo import conninfo_to_dict
from psycopg_pool import AsyncConnectionPool
from psycopg_pool.abc import ACT
from pydantic import Field, SecretStr

from semantic_kernel.connectors.memory.postgres.constants import (
Expand All @@ -14,10 +15,7 @@
PGSSL_MODE_ENV_VAR,
PGUSER_ENV_VAR,
)
from semantic_kernel.exceptions.memory_connector_exceptions import (
MemoryConnectorConnectionException,
MemoryConnectorInitializationError,
)
from semantic_kernel.exceptions.memory_connector_exceptions import MemoryConnectorConnectionException
from semantic_kernel.kernel_pydantic import KernelBaseSettings
from semantic_kernel.utils.experimental_decorator import experimental_class

Expand Down Expand Up @@ -89,30 +87,34 @@ def get_connection_args(self) -> dict[str, Any]:
if self.password:
result["password"] = self.password.get_secret_value()

# Ensure required values
if "host" not in result:
raise MemoryConnectorInitializationError("host is required. Please set PGHOST or connection_string.")
if "dbname" not in result:
raise MemoryConnectorInitializationError(
"database is required. Please set PGDATABASE or connection_string."
)
if "user" not in result:
raise MemoryConnectorInitializationError("user is required. Please set PGUSER or connection_string.")
if "password" not in result:
raise MemoryConnectorInitializationError(
"password is required. Please set PGPASSWORD or connection_string."
)

return result

async def create_connection_pool(self) -> AsyncConnectionPool:
"""Creates a connection pool based off of settings."""
async def create_connection_pool(
self, connection_class: type[ACT] | None = None, **kwargs: Any
) -> AsyncConnectionPool:
"""Creates a connection pool based off of settings.

Args:
connection_class: The connection class to use.
kwargs: Additional keyword arguments to pass to the connection class.

Returns:
The connection pool.
"""
try:
# Only pass connection_class if it is specified; otherwise allow psycopg to use the default connection class
extra_args: dict[str, Any] = {} if connection_class is None else {"connection_class": connection_class}

pool = AsyncConnectionPool(
min_size=self.min_pool,
max_size=self.max_pool,
open=False,
kwargs=self.get_connection_args(),
# kwargs are passed to the connection class
kwargs={
**self.get_connection_args(),
**kwargs,
},
**extra_args,
)
await pool.open()
except Exception as e:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,20 @@
import sys
from typing import Any, TypeVar

if sys.version_info >= (3, 12):
from typing import override # pragma: no cover
else:
from typing_extensions import override # pragma: no cover

from psycopg import sql
from psycopg_pool import AsyncConnectionPool

from semantic_kernel.connectors.memory.postgres.postgres_collection import PostgresCollection
from semantic_kernel.connectors.memory.postgres.postgres_memory_store import DEFAULT_SCHEMA
from semantic_kernel.data.record_definition.vector_store_model_definition import VectorStoreRecordDefinition
from semantic_kernel.data.vector_storage.vector_store import VectorStore
from semantic_kernel.data.vector_storage.vector_store_record_collection import VectorStoreRecordCollection
from semantic_kernel.data import VectorStore, VectorStoreRecordCollection, VectorStoreRecordDefinition
from semantic_kernel.utils.experimental_decorator import experimental_class

if sys.version_info >= (3, 12):
from typing import override # pragma: no cover
else:
from typing_extensions import override # pragma: no cover


logger: logging.Logger = logging.getLogger(__name__)

TModel = TypeVar("TModel")
Expand Down
45 changes: 41 additions & 4 deletions python/semantic_kernel/connectors/memory/postgres/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def python_type_to_postgres(python_type_str: str) -> str | None:
return None


def convert_row_to_dict(row: tuple[Any, ...], fields: list[tuple[str, VectorStoreRecordField]]) -> dict[str, Any]:
def convert_row_to_dict(
row: tuple[Any, ...], fields: list[tuple[str, VectorStoreRecordField | None]]
) -> dict[str, Any]:
"""Convert a row from a PostgreSQL query to a dictionary.

Uses the field information to map the row values to the corresponding field names.
Expand All @@ -65,11 +67,12 @@ def convert_row_to_dict(row: tuple[Any, ...], fields: list[tuple[str, VectorStor
A dictionary representation of the row.
"""

def _convert(v: Any | None, field: VectorStoreRecordField) -> Any | None:
def _convert(v: Any | None, field: VectorStoreRecordField | None) -> Any | None:
if v is None:
return None
if isinstance(field, VectorStoreRecordVectorField):
# psycopg returns vector as a string
if isinstance(field, VectorStoreRecordVectorField) and isinstance(v, str):
# psycopg returns vector as a string if pgvector is not loaded.
# If pgvector is registered with the connection, no conversion is required.
return json.loads(v)
return v

Expand Down Expand Up @@ -109,6 +112,8 @@ def get_vector_index_ops_str(distance_function: DistanceFunction) -> str:
>>> get_vector_index_ops_str(DistanceFunction.COSINE)
'vector_cosine_ops'
"""
if distance_function == DistanceFunction.COSINE_DISTANCE:
return "vector_cosine_ops"
if distance_function == DistanceFunction.COSINE_SIMILARITY:
return "vector_cosine_ops"
if distance_function == DistanceFunction.DOT_PROD:
Expand All @@ -121,6 +126,38 @@ def get_vector_index_ops_str(distance_function: DistanceFunction) -> str:
raise ValueError(f"Unsupported distance function: {distance_function}")


def get_vector_distance_ops_str(distance_function: DistanceFunction) -> str:
    """Get the PostgreSQL distance operator string for a given distance function.

    Args:
        distance_function: The distance function for which the operator string is needed.

    Note:
        For the COSINE_SIMILARITY and DOT_PROD distance functions, additional query
        steps are required to retrieve the correct distance:
        for dot product, take -1 * the inner product, since <#> returns the negative
        inner product (Postgres only supports ASC order index scans on operators);
        for cosine similarity, take 1 - the cosine distance returned by <=>.

    Returns:
        The PostgreSQL distance operator string for the given distance function.

    Raises:
        ValueError: If the distance function is unsupported.
    """
    # Map each supported distance function to its pgvector operator.
    operator_by_function: dict[DistanceFunction, str] = {
        DistanceFunction.COSINE_DISTANCE: "<=>",
        DistanceFunction.COSINE_SIMILARITY: "<=>",
        DistanceFunction.DOT_PROD: "<#>",
        DistanceFunction.EUCLIDEAN_DISTANCE: "<->",
        DistanceFunction.MANHATTAN: "<+>",
    }
    if distance_function not in operator_by_function:
        raise ValueError(f"Unsupported distance function: {distance_function}")
    return operator_by_function[distance_function]


async def ensure_open(connection_pool: AsyncConnectionPool) -> AsyncConnectionPool:
"""Ensure the connection pool is open.

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.

import uuid
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Sequence
from contextlib import asynccontextmanager
from typing import Annotated, Any

Expand All @@ -11,6 +11,7 @@
from pydantic import BaseModel

from semantic_kernel.connectors.memory.postgres import PostgresSettings, PostgresStore
from semantic_kernel.connectors.memory.postgres.postgres_collection import PostgresCollection
from semantic_kernel.data import (
DistanceFunction,
IndexKind,
Expand All @@ -20,6 +21,7 @@
VectorStoreRecordVectorField,
vectorstoremodel,
)
from semantic_kernel.data.vector_search.vector_search_options import VectorSearchOptions
from semantic_kernel.exceptions.memory_connector_exceptions import (
MemoryConnectorConnectionException,
MemoryConnectorInitializationError,
Expand Down Expand Up @@ -49,13 +51,13 @@
class SimpleDataModel(BaseModel):
id: Annotated[int, VectorStoreRecordKeyField()]
embedding: Annotated[
list[float],
list[float] | None,
VectorStoreRecordVectorField(
index_kind=IndexKind.HNSW,
dimensions=3,
distance_function=DistanceFunction.COSINE_SIMILARITY,
),
]
] = None
data: Annotated[
dict[str, Any],
VectorStoreRecordDataField(has_embedding=True, embedding_property_name="embedding", property_type="JSONB"),
Expand Down Expand Up @@ -97,7 +99,9 @@ async def vector_store() -> AsyncGenerator[PostgresStore, None]:


@asynccontextmanager
async def create_simple_collection(vector_store: PostgresStore):
async def create_simple_collection(
vector_store: PostgresStore,
) -> AsyncGenerator[PostgresCollection[int, SimpleDataModel], None]:
"""Returns a collection with a unique name that is deleted after the context.

This can be moved to use a fixture with scope=function and loop_scope=session
Expand All @@ -107,6 +111,7 @@ async def create_simple_collection(vector_store: PostgresStore):
suffix = str(uuid.uuid4()).replace("-", "")[:8]
collection_id = f"test_collection_{suffix}"
collection = vector_store.get_collection(collection_id, SimpleDataModel)
assert isinstance(collection, PostgresCollection)
await collection.create_collection()
try:
yield collection
Expand Down Expand Up @@ -213,6 +218,7 @@ async def test_upsert_get_and_delete_batch(vector_store: PostgresStore):
# this should return only the two existing records.
result = await simple_collection.get_batch([1, 2, 3])
assert result is not None
assert isinstance(result, Sequence)
assert len(result) == 2
assert result[0] is not None
assert result[0].id == record1.id
Expand All @@ -226,3 +232,28 @@ async def test_upsert_get_and_delete_batch(vector_store: PostgresStore):
await simple_collection.delete_batch([1, 2])
result_after_delete = await simple_collection.get_batch([1, 2])
assert result_after_delete is None


async def test_search(vector_store: PostgresStore):
    # Vectorized search should return the top-k nearest records and, when
    # requested, the total number of matching rows.
    async with create_simple_collection(vector_store) as simple_collection:
        test_records = [
            SimpleDataModel(id=1, embedding=[1.0, 0.0, 0.0], data={"key": "value1"}),
            SimpleDataModel(id=2, embedding=[0.8, 0.2, 0.0], data={"key": "value2"}),
            SimpleDataModel(id=3, embedding=[0.6, 0.0, 0.4], data={"key": "value3"}),
            SimpleDataModel(id=4, embedding=[1.0, 1.0, 0.0], data={"key": "value4"}),
            SimpleDataModel(id=5, embedding=[0.0, 1.0, 1.0], data={"key": "value5"}),
            SimpleDataModel(id=6, embedding=[1.0, 0.0, 1.0], data={"key": "value6"}),
        ]

        await simple_collection.upsert_batch(test_records)

        try:
            options = VectorSearchOptions(top=3, include_total_count=True)
            search_results = await simple_collection.vectorized_search([1.0, 0.0, 0.0], options=options)
            assert search_results is not None
            assert search_results.total_count == 3
            # The three records closest to [1, 0, 0] by cosine similarity.
            found_ids = {result.record.id async for result in search_results.results}
            assert found_ids == {1, 2, 3}

        finally:
            await simple_collection.delete_batch([record.id for record in test_records])
Loading
Loading