Skip to content

Commit 3724e24

Browse files
authored
Python: Add vector search to Postgres connector (#10213)
### Motivation and Context Following up on #8951, this PR adds an implementation of `VectorSearchBase` to `PostgresCollection`. This implementation provides vectorized search and does not implement text search or vectorizable text search. Unit and integration tests are added, and the `python/samples/getting_started/third_party/postgres-memory.ipynb` notebook was expanded to include vector search in the example. ### Contribution Checklist - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone 😄
1 parent f32972c commit 3724e24

File tree

8 files changed

+955
-136
lines changed

8 files changed

+955
-136
lines changed

python/samples/getting_started/third_party/postgres-memory.ipynb

+400-27
Large diffs are not rendered by default.

python/semantic_kernel/connectors/memory/postgres/constants.py

+4
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
# Limitation based on pgvector documentation https://github.com/pgvector/pgvector#what-if-i-want-to-index-vectors-with-more-than-2000-dimensions
66
MAX_DIMENSIONALITY = 2000
77

8+
# The name of the column that returns distance value in the database.
9+
# It is used in the similarity search query and must not conflict with any model property name.
10+
DISTANCE_COLUMN_NAME = "sk_pg_distance"
11+
812
# Environment Variables
913
PGHOST_ENV_VAR = "PGHOST"
1014
PGPORT_ENV_VAR = "PGPORT"

python/semantic_kernel/connectors/memory/postgres/postgres_collection.py

+311-71
Large diffs are not rendered by default.

python/semantic_kernel/connectors/memory/postgres/postgres_settings.py

+23-21
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from psycopg.conninfo import conninfo_to_dict
66
from psycopg_pool import AsyncConnectionPool
7+
from psycopg_pool.abc import ACT
78
from pydantic import Field, SecretStr
89

910
from semantic_kernel.connectors.memory.postgres.constants import (
@@ -14,10 +15,7 @@
1415
PGSSL_MODE_ENV_VAR,
1516
PGUSER_ENV_VAR,
1617
)
17-
from semantic_kernel.exceptions.memory_connector_exceptions import (
18-
MemoryConnectorConnectionException,
19-
MemoryConnectorInitializationError,
20-
)
18+
from semantic_kernel.exceptions.memory_connector_exceptions import MemoryConnectorConnectionException
2119
from semantic_kernel.kernel_pydantic import KernelBaseSettings
2220
from semantic_kernel.utils.experimental_decorator import experimental_class
2321

@@ -89,30 +87,34 @@ def get_connection_args(self) -> dict[str, Any]:
8987
if self.password:
9088
result["password"] = self.password.get_secret_value()
9189

92-
# Ensure required values
93-
if "host" not in result:
94-
raise MemoryConnectorInitializationError("host is required. Please set PGHOST or connection_string.")
95-
if "dbname" not in result:
96-
raise MemoryConnectorInitializationError(
97-
"database is required. Please set PGDATABASE or connection_string."
98-
)
99-
if "user" not in result:
100-
raise MemoryConnectorInitializationError("user is required. Please set PGUSER or connection_string.")
101-
if "password" not in result:
102-
raise MemoryConnectorInitializationError(
103-
"password is required. Please set PGPASSWORD or connection_string."
104-
)
105-
10690
return result
10791

108-
async def create_connection_pool(self) -> AsyncConnectionPool:
109-
"""Creates a connection pool based off of settings."""
92+
async def create_connection_pool(
93+
self, connection_class: type[ACT] | None = None, **kwargs: Any
94+
) -> AsyncConnectionPool:
95+
"""Creates a connection pool based off of settings.
96+
97+
Args:
98+
connection_class: The connection class to use.
99+
kwargs: Additional keyword arguments to pass to the connection class.
100+
101+
Returns:
102+
The connection pool.
103+
"""
110104
try:
105+
# Only pass connection_class if it is specified; otherwise allow psycopg to use the default connection class
106+
extra_args: dict[str, Any] = {} if connection_class is None else {"connection_class": connection_class}
107+
111108
pool = AsyncConnectionPool(
112109
min_size=self.min_pool,
113110
max_size=self.max_pool,
114111
open=False,
115-
kwargs=self.get_connection_args(),
112+
# kwargs are passed to the connection class
113+
kwargs={
114+
**self.get_connection_args(),
115+
**kwargs,
116+
},
117+
**extra_args,
116118
)
117119
await pool.open()
118120
except Exception as e:

python/semantic_kernel/connectors/memory/postgres/postgres_store.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,20 @@
44
import sys
55
from typing import Any, TypeVar
66

7-
if sys.version_info >= (3, 12):
8-
from typing import override # pragma: no cover
9-
else:
10-
from typing_extensions import override # pragma: no cover
11-
127
from psycopg import sql
138
from psycopg_pool import AsyncConnectionPool
149

1510
from semantic_kernel.connectors.memory.postgres.postgres_collection import PostgresCollection
1611
from semantic_kernel.connectors.memory.postgres.postgres_memory_store import DEFAULT_SCHEMA
17-
from semantic_kernel.data.record_definition.vector_store_model_definition import VectorStoreRecordDefinition
18-
from semantic_kernel.data.vector_storage.vector_store import VectorStore
19-
from semantic_kernel.data.vector_storage.vector_store_record_collection import VectorStoreRecordCollection
12+
from semantic_kernel.data import VectorStore, VectorStoreRecordCollection, VectorStoreRecordDefinition
2013
from semantic_kernel.utils.experimental_decorator import experimental_class
2114

15+
if sys.version_info >= (3, 12):
16+
from typing import override # pragma: no cover
17+
else:
18+
from typing_extensions import override # pragma: no cover
19+
20+
2221
logger: logging.Logger = logging.getLogger(__name__)
2322

2423
TModel = TypeVar("TModel")

python/semantic_kernel/connectors/memory/postgres/utils.py

+41-4
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,9 @@ def python_type_to_postgres(python_type_str: str) -> str | None:
5252
return None
5353

5454

55-
def convert_row_to_dict(row: tuple[Any, ...], fields: list[tuple[str, VectorStoreRecordField]]) -> dict[str, Any]:
55+
def convert_row_to_dict(
56+
row: tuple[Any, ...], fields: list[tuple[str, VectorStoreRecordField | None]]
57+
) -> dict[str, Any]:
5658
"""Convert a row from a PostgreSQL query to a dictionary.
5759
5860
Uses the field information to map the row values to the corresponding field names.
@@ -65,11 +67,12 @@ def convert_row_to_dict(row: tuple[Any, ...], fields: list[tuple[str, VectorStor
6567
A dictionary representation of the row.
6668
"""
6769

68-
def _convert(v: Any | None, field: VectorStoreRecordField) -> Any | None:
70+
def _convert(v: Any | None, field: VectorStoreRecordField | None) -> Any | None:
6971
if v is None:
7072
return None
71-
if isinstance(field, VectorStoreRecordVectorField):
72-
# psycopg returns vector as a string
73+
if isinstance(field, VectorStoreRecordVectorField) and isinstance(v, str):
74+
# psycopg returns vector as a string if pgvector is not loaded.
75+
# If pgvector is registered with the connection, no conversion is required.
7376
return json.loads(v)
7477
return v
7578

@@ -109,6 +112,8 @@ def get_vector_index_ops_str(distance_function: DistanceFunction) -> str:
109112
>>> get_vector_index_ops_str(DistanceFunction.COSINE_SIMILARITY)
110113
'vector_cosine_ops'
111114
"""
115+
if distance_function == DistanceFunction.COSINE_DISTANCE:
116+
return "vector_cosine_ops"
112117
if distance_function == DistanceFunction.COSINE_SIMILARITY:
113118
return "vector_cosine_ops"
114119
if distance_function == DistanceFunction.DOT_PROD:
@@ -121,6 +126,38 @@ def get_vector_index_ops_str(distance_function: DistanceFunction) -> str:
121126
raise ValueError(f"Unsupported distance function: {distance_function}")
122127

123128

129+
def get_vector_distance_ops_str(distance_function: DistanceFunction) -> str:
130+
"""Get the PostgreSQL distance operator string for a given distance function.
131+
132+
Args:
133+
distance_function: The distance function for which the operator string is needed.
134+
135+
Note:
136+
For the COSINE_SIMILARITY and DOT_PROD distance functions,
137+
there are additional query steps to retrieve the correct distance.
138+
For dot product, take -1 * inner product, as <#> returns the negative inner product
139+
since Postgres only supports ASC order index scans on operators.
140+
For cosine similarity, take 1 - cosine distance.
141+
142+
Returns:
143+
The PostgreSQL distance operator string for the given distance function.
144+
145+
Raises:
146+
ValueError: If the distance function is unsupported.
147+
"""
148+
if distance_function == DistanceFunction.COSINE_DISTANCE:
149+
return "<=>"
150+
if distance_function == DistanceFunction.COSINE_SIMILARITY:
151+
return "<=>"
152+
if distance_function == DistanceFunction.DOT_PROD:
153+
return "<#>"
154+
if distance_function == DistanceFunction.EUCLIDEAN_DISTANCE:
155+
return "<->"
156+
if distance_function == DistanceFunction.MANHATTAN:
157+
return "<+>"
158+
raise ValueError(f"Unsupported distance function: {distance_function}")
159+
160+
124161
async def ensure_open(connection_pool: AsyncConnectionPool) -> AsyncConnectionPool:
125162
"""Ensure the connection pool is open.
126163

python/tests/integration/memory/vector_stores/postgres/test_postgres_int.py

+35-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Copyright (c) Microsoft. All rights reserved.
22

33
import uuid
4-
from collections.abc import AsyncGenerator
4+
from collections.abc import AsyncGenerator, Sequence
55
from contextlib import asynccontextmanager
66
from typing import Annotated, Any
77

@@ -11,6 +11,7 @@
1111
from pydantic import BaseModel
1212

1313
from semantic_kernel.connectors.memory.postgres import PostgresSettings, PostgresStore
14+
from semantic_kernel.connectors.memory.postgres.postgres_collection import PostgresCollection
1415
from semantic_kernel.data import (
1516
DistanceFunction,
1617
IndexKind,
@@ -20,6 +21,7 @@
2021
VectorStoreRecordVectorField,
2122
vectorstoremodel,
2223
)
24+
from semantic_kernel.data.vector_search.vector_search_options import VectorSearchOptions
2325
from semantic_kernel.exceptions.memory_connector_exceptions import (
2426
MemoryConnectorConnectionException,
2527
MemoryConnectorInitializationError,
@@ -49,13 +51,13 @@
4951
class SimpleDataModel(BaseModel):
5052
id: Annotated[int, VectorStoreRecordKeyField()]
5153
embedding: Annotated[
52-
list[float],
54+
list[float] | None,
5355
VectorStoreRecordVectorField(
5456
index_kind=IndexKind.HNSW,
5557
dimensions=3,
5658
distance_function=DistanceFunction.COSINE_SIMILARITY,
5759
),
58-
]
60+
] = None
5961
data: Annotated[
6062
dict[str, Any],
6163
VectorStoreRecordDataField(has_embedding=True, embedding_property_name="embedding", property_type="JSONB"),
@@ -97,7 +99,9 @@ async def vector_store() -> AsyncGenerator[PostgresStore, None]:
9799

98100

99101
@asynccontextmanager
100-
async def create_simple_collection(vector_store: PostgresStore):
102+
async def create_simple_collection(
103+
vector_store: PostgresStore,
104+
) -> AsyncGenerator[PostgresCollection[int, SimpleDataModel], None]:
101105
"""Returns a collection with a unique name that is deleted after the context.
102106
103107
This can be moved to use a fixture with scope=function and loop_scope=session
@@ -107,6 +111,7 @@ async def create_simple_collection(vector_store: PostgresStore):
107111
suffix = str(uuid.uuid4()).replace("-", "")[:8]
108112
collection_id = f"test_collection_{suffix}"
109113
collection = vector_store.get_collection(collection_id, SimpleDataModel)
114+
assert isinstance(collection, PostgresCollection)
110115
await collection.create_collection()
111116
try:
112117
yield collection
@@ -213,6 +218,7 @@ async def test_upsert_get_and_delete_batch(vector_store: PostgresStore):
213218
# this should return only the two existing records.
214219
result = await simple_collection.get_batch([1, 2, 3])
215220
assert result is not None
221+
assert isinstance(result, Sequence)
216222
assert len(result) == 2
217223
assert result[0] is not None
218224
assert result[0].id == record1.id
@@ -226,3 +232,28 @@ async def test_upsert_get_and_delete_batch(vector_store: PostgresStore):
226232
await simple_collection.delete_batch([1, 2])
227233
result_after_delete = await simple_collection.get_batch([1, 2])
228234
assert result_after_delete is None
235+
236+
237+
async def test_search(vector_store: PostgresStore):
238+
async with create_simple_collection(vector_store) as simple_collection:
239+
records = [
240+
SimpleDataModel(id=1, embedding=[1.0, 0.0, 0.0], data={"key": "value1"}),
241+
SimpleDataModel(id=2, embedding=[0.8, 0.2, 0.0], data={"key": "value2"}),
242+
SimpleDataModel(id=3, embedding=[0.6, 0.0, 0.4], data={"key": "value3"}),
243+
SimpleDataModel(id=4, embedding=[1.0, 1.0, 0.0], data={"key": "value4"}),
244+
SimpleDataModel(id=5, embedding=[0.0, 1.0, 1.0], data={"key": "value5"}),
245+
SimpleDataModel(id=6, embedding=[1.0, 0.0, 1.0], data={"key": "value6"}),
246+
]
247+
248+
await simple_collection.upsert_batch(records)
249+
250+
try:
251+
search_results = await simple_collection.vectorized_search(
252+
[1.0, 0.0, 0.0], options=VectorSearchOptions(top=3, include_total_count=True)
253+
)
254+
assert search_results is not None
255+
assert search_results.total_count == 3
256+
assert {result.record.id async for result in search_results.results} == {1, 2, 3}
257+
258+
finally:
259+
await simple_collection.delete_batch([r.id for r in records])

0 commit comments

Comments
 (0)