Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "uv_build"

[project]
name = "py-pglite"
version = "0.5.1"
version = "0.5.2"
description = "Python testing library for PGlite - in-memory PostgreSQL for tests"
readme = "README.md"
license = "Apache-2.0"
Expand Down
7 changes: 4 additions & 3 deletions src/py_pglite/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from dataclasses import dataclass
from dataclasses import field
from pathlib import Path
from typing import Literal

from py_pglite.extensions import SUPPORTED_EXTENSIONS

Expand Down Expand Up @@ -86,19 +87,19 @@ def log_level_int(self) -> int:
level_value = getattr(logging, self.log_level)
return int(level_value)

def get_connection_string(self) -> str:
def get_connection_string(self, driver: Literal["psycopg", "psycopg2"] = "psycopg") -> str:
"""Get PostgreSQL connection string for SQLAlchemy usage."""
if self.use_tcp:
# TCP connection string
return f"postgresql+psycopg://postgres:postgres@{self.tcp_host}:{self.tcp_port}/postgres?sslmode=disable"
return f"postgresql+{driver}://postgres:postgres@{self.tcp_host}:{self.tcp_port}/postgres?sslmode=disable"

# For SQLAlchemy with Unix domain sockets, we need to specify the directory
# and use the standard PostgreSQL socket naming convention
socket_dir = str(Path(self.socket_path).parent)

# Use the socket directory as host - psycopg will look for .s.PGSQL.5432
connection_string = (
f"postgresql+psycopg://postgres:postgres@/postgres?host={socket_dir}"
f"postgresql+{driver}://postgres:postgres@/postgres?host={socket_dir}"
)

return connection_string
Expand Down
7 changes: 4 additions & 3 deletions src/py_pglite/sqlalchemy/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import time

from typing import Any
from typing import Any, Literal

from py_pglite.manager import PGliteManager

Expand All @@ -22,7 +22,7 @@ def __enter__(self) -> "SQLAlchemyPGliteManager":
super().__enter__()
return self

def get_engine(self, **engine_kwargs: Any) -> Any:
def get_engine(self, driver: Literal["psycopg", "psycopg2"] = "psycopg", **engine_kwargs: Any) -> Any:
"""Get SQLAlchemy engine connected to PGlite.

NOTE: This method requires SQLAlchemy to be installed.
Expand All @@ -33,6 +33,7 @@ def get_engine(self, **engine_kwargs: Any) -> Any:
architecture ensures all database operations use the same connection.

Args:
driver: Which driver to use for connecting to the Postgres database. Defaults to 'psycopg'.
**engine_kwargs: Additional arguments for create_engine

Returns:
Expand Down Expand Up @@ -96,7 +97,7 @@ def get_engine(self, **engine_kwargs: Any) -> Any:

# Create and store the shared engine
self._shared_engine = create_engine(
self.config.get_connection_string(), **final_kwargs
self.config.get_connection_string(driver), **final_kwargs
)
return self._shared_engine

Expand Down