Merged
3 changes: 2 additions & 1 deletion examples/conftest.py
@@ -102,7 +102,8 @@ def pglite_session(pglite_engine: Engine) -> Generator[Any, None, None]:
                 logger.info(f"Truncating table: {table_name}")
                 conn.execute(
                     text(
-                        f'TRUNCATE TABLE "{table_name}" RESTART IDENTITY CASCADE;'
+                        f'TRUNCATE TABLE "{table_name}" '
+                        "RESTART IDENTITY CASCADE;"
                     )
                 )
 
6 changes: 4 additions & 2 deletions examples/test_utils.py → examples/test_example_utils.py
@@ -170,7 +170,8 @@ def test_schema_operations(pglite_engine):
         with session.connection() as conn:
             result = conn.execute(
                 text(
-                    "SELECT schema_name FROM information_schema.schemata WHERE schema_name = :name"
+                    "SELECT schema_name FROM information_schema.schemata "
+                    "WHERE schema_name = :name"
                 ),
                 {"name": test_schema},
             )
@@ -185,7 +186,8 @@ def test_schema_operations(pglite_engine):
         with session.connection() as conn:
             result = conn.execute(
                 text(
-                    "SELECT schema_name FROM information_schema.schemata WHERE schema_name = :name"
+                    "SELECT schema_name FROM information_schema.schemata "
+                    "WHERE schema_name = :name"
                 ),
                 {"name": test_schema},
             )
11 changes: 7 additions & 4 deletions examples/testing-patterns/test_performance_benchmarks.py
@@ -251,8 +251,7 @@ def test_large_query_performance(self, benchmark_engine):
         print("\n📊 Large Query Performance Test")
         print("=" * 50)
 
-        # Setup small dataset
-        batch_size = 100
+        batch_size = 50
         users = [
             BenchmarkUser(
                 username=f"large_user_{i}",
@@ -265,8 +264,12 @@ def test_large_query_performance(self, benchmark_engine):
 
         start_time = time.time()
         with resilient_session(benchmark_engine) as session:
-            session.add_all(users)
-            session.commit()
+            # Insert in smaller chunks to be more stable
+            chunk_size = 25
+            for i in range(0, len(users), chunk_size):
+                chunk = users[i : i + chunk_size]
+                session.add_all(chunk)
+                session.commit()
 
         insert_duration = time.time() - start_time
         print(f" 📥 Data setup: {batch_size} users in {insert_duration:.2f}s")
2 changes: 1 addition & 1 deletion py_pglite/__init__.py
@@ -4,7 +4,7 @@
 and Python test suites with support for SQLAlchemy, SQLModel, and Django.
 """
 
-__version__ = "0.4.0"
+__version__ = "0.4.1"
 
 # Core exports (always available)
 # Database client exports (choose your preferred client)
48 changes: 34 additions & 14 deletions py_pglite/clients.py
@@ -129,26 +129,35 @@
     ) -> list[tuple]:
         """Execute query using asyncpg (sync wrapper)."""
         loop = self._get_event_loop()
-        return loop.run_until_complete(
-            self._async_execute_query(connection, query, params)
-        )
+        try:
+            return loop.run_until_complete(
+                self._async_execute_query(connection, query, params)
+            )
+        except Exception as e:
+            # Ensure we don't leave any coroutines hanging
+            logger.warning(f"AsyncpgClient execute_query failed: {e}")
+            raise
 
     async def _async_execute_query(
         self, connection: Any, query: str, params: Any = None
     ) -> list[tuple]:
         """Execute query using asyncpg (async)."""
-        if params:
-            if isinstance(params, list | tuple) and len(params) == 1:
-                # Single parameter
-                result = await connection.fetch(query, params[0])
-            else:
-                # Multiple parameters
-                result = await connection.fetch(query, *params)
-        else:
-            result = await connection.fetch(query)
-
-        # Convert asyncpg Records to tuples
-        return [tuple(row) for row in result]
+        try:
+            if params:
+                if isinstance(params, list | tuple) and len(params) == 1:
+                    # Single parameter
+                    result = await connection.fetch(query, params[0])
+                else:
+                    # Multiple parameters
+                    result = await connection.fetch(query, *params)
+            else:
+                result = await connection.fetch(query)
+
+            # Convert asyncpg Records to tuples
+            return [tuple(row) for row in result]
+        except Exception as e:
+            logger.warning(f"AsyncpgClient async query execution failed: {e}")
+            raise
 
     def test_connection(self, connection_string: str) -> bool:
         """Test asyncpg connection."""
@@ -185,7 +194,18 @@
     def _get_event_loop(self):
         """Get or create event loop."""
         try:
-            return self._asyncio.get_event_loop()
+            # Try to get the current event loop
+            loop = self._asyncio.get_event_loop()
+            # Check if loop is running - if so, we need a new thread
+            if loop.is_running():
+                # If we're in a running loop (like in pytest), we can't use
+                # run_until_complete
+                # This is a potential source of the warning - let's handle it better
+                logger.warning(
+                    "AsyncpgClient: Event loop is already running. "
+                    "Consider using psycopg client for synchronous usage."
+                )
+                return loop
         except RuntimeError:
             # No event loop in current thread, create a new one
             loop = self._asyncio.new_event_loop()
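Note (not part of the diff): the `_get_event_loop()` guard above exists because `run_until_complete()` cannot be called on a loop that is already running, which is exactly the situation under async test runners. Below is a minimal, self-contained sketch of that pitfall; `fetch_rows`, `run_sync`, and `inside_running_loop` are illustrative names, not py-pglite APIs.

```python
import asyncio


async def fetch_rows() -> list[tuple]:
    """Stand-in for an async driver call such as connection.fetch()."""
    await asyncio.sleep(0)
    return [(1, "ok")]


def run_sync() -> list[tuple]:
    """Sync-over-async works when no loop is running in this thread."""
    loop = asyncio.new_event_loop()
    try:
        return loop.run_until_complete(fetch_rows())
    finally:
        loop.close()


async def inside_running_loop() -> None:
    """Inside a running loop, run_until_complete() raises RuntimeError."""
    loop = asyncio.get_running_loop()
    coro = fetch_rows()
    try:
        loop.run_until_complete(coro)
    except RuntimeError as exc:
        # This is the case the new warning points at; a synchronous client
        # (psycopg) avoids it entirely.
        coro.close()  # avoid the "coroutine was never awaited" warning
        print(f"expected failure: {exc}")


if __name__ == "__main__":
    print(run_sync())
    asyncio.run(inside_running_loop())
```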
15 changes: 12 additions & 3 deletions py_pglite/config.py
@@ -3,6 +3,7 @@
 import logging
 import os
 import tempfile
+import uuid
 from dataclasses import dataclass, field
 from pathlib import Path
@@ -11,7 +12,9 @@
 
 def _get_secure_socket_path() -> str:
     """Generate a secure socket path in user's temp directory."""
-    temp_dir = Path(tempfile.gettempdir()) / f"py-pglite-{os.getpid()}"
+    # Use both PID and UUID to ensure uniqueness
+    unique_id = f"{os.getpid()}-{uuid.uuid4().hex[:8]}"
+    temp_dir = Path(tempfile.gettempdir()) / f"py-pglite-{unique_id}"
     temp_dir.mkdir(mode=0o700, exist_ok=True)  # Restrict to user only
     # Use PostgreSQL's standard socket naming convention
     return str(temp_dir / ".s.PGSQL.5432")
@@ -69,8 +72,8 @@
         return int(level_value)
 
     def get_connection_string(self) -> str:
-        """Get PostgreSQL connection string for PGlite."""
-        # For psycopg with Unix domain sockets, we need to specify the directory
+        """Get PostgreSQL connection string for SQLAlchemy usage."""
+        # For SQLAlchemy with Unix domain sockets, we need to specify the directory
         # and use the standard PostgreSQL socket naming convention
         socket_dir = str(Path(self.socket_path).parent)
@@ -81,6 +84,12 @@
 
         return connection_string
 
+    def get_psycopg_uri(self) -> str:
+        """Get PostgreSQL URI for direct psycopg usage."""
+        socket_dir = str(Path(self.socket_path).parent)
+        # Use standard PostgreSQL URI format for psycopg
+        return f"postgresql://postgres:postgres@/postgres?host={socket_dir}"
+
     def get_dsn(self) -> str:
         """Get PostgreSQL DSN connection string for direct psycopg usage."""
         socket_dir = str(Path(self.socket_path).parent)
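Note (not part of the diff): a rough usage sketch of the split between the SQLAlchemy-oriented `get_connection_string()` and the new `get_psycopg_uri()`. It assumes `PGliteConfig` is importable from the package root (as in the tests below), that `socket_path` defaults through `_get_secure_socket_path()`, and that a PGlite server is already listening on that socket; `psycopg` here is the separately installed psycopg 3 package.

```python
import psycopg  # psycopg 3; requires `pip install psycopg`

from py_pglite import PGliteConfig

config = PGliteConfig()

# If _get_secure_socket_path() is the default factory for socket_path, the new
# PID + UUID suffix should give two configs distinct socket directories.
assert PGliteConfig().socket_path != PGliteConfig().socket_path

# SQLAlchemy-style URL (exact driver prefix is up to the library)
print(config.get_connection_string())

# New in this PR: plain PostgreSQL URI for direct psycopg usage, pointing the
# `host` parameter at the Unix socket directory
uri = config.get_psycopg_uri()

# Only works once a PGlite server is actually listening on that socket
with psycopg.connect(uri) as conn, conn.cursor() as cur:
    cur.execute("SELECT 1")
    print(cur.fetchone())
```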
48 changes: 40 additions & 8 deletions py_pglite/manager.py
@@ -188,16 +188,24 @@
             self.logger.warning(f"Failed to clean up socket: {e}")
 
     def _kill_existing_processes(self) -> None:
-        """Kill any existing PGlite processes."""
+        """Kill any existing PGlite processes that might conflict with this socket."""
         try:
-            for proc in psutil.process_iter(["pid", "name", "cmdline"]):
+            my_socket_dir = str(Path(self.config.socket_path).parent)
+            for proc in psutil.process_iter(["pid", "name", "cmdline", "cwd"]):
                 if proc.info["cmdline"] and any(
                     "pglite_manager.js" in cmd for cmd in proc.info["cmdline"]
                 ):
-                    pid = proc.info["pid"]
-                    self.logger.info(f"Killing existing PGlite process: {pid}")
-                    proc.kill()
-                    proc.wait(timeout=5)
+                    # Only kill processes in the same socket directory to avoid killing other instances
+                    try:
+                        proc_cwd = proc.info.get("cwd", "")
+                        if my_socket_dir in proc_cwd or proc_cwd in my_socket_dir:
+                            pid = proc.info["pid"]
+                            self.logger.info(f"Killing existing PGlite process: {pid}")
+                            proc.kill()
+                            proc.wait(timeout=5)
+                    except (psutil.NoSuchProcess, psutil.AccessDenied):
+                        # Process already gone or can't access it
+                        continue
         except Exception as e:
             self.logger.warning(f"Error killing existing PGlite processes: {e}")
 
@@ -400,11 +408,12 @@
         Returns:
             True if database becomes ready, False otherwise
         """
-        from .utils import test_connection
+        from .utils import check_connection
 
         for attempt in range(max_retries):
             try:
-                if test_connection(self.config.get_connection_string()):
+                # Use DSN format for direct psycopg connection testing
+                if check_connection(self.config.get_dsn()):
                     self.logger.info(f"Database ready after {attempt + 1} attempts")
                     time.sleep(0.2)  # Small stability delay
                     return True
@@ -436,3 +445,26 @@
             True if database becomes ready, False otherwise
         """
         return self.wait_for_ready_basic(max_retries=max_retries, delay=delay)
+
+    def restart(self) -> None:
+        """Restart the PGlite server.
+
+        Stops the current server if running and starts a new one.
+        """
+        if self.is_running():
+            self.stop()
+        self.start()
+
+    def get_psycopg_uri(self) -> str:
+        """Get the database URI for psycopg usage.
+
+        Returns:
+            PostgreSQL URI string compatible with psycopg
+
+        Raises:
+            RuntimeError: If PGlite server is not running
+        """
+        if not self.is_running():
+            raise RuntimeError("PGlite server is not running. Call start() first.")
+
+        return self.config.get_psycopg_uri()
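Note (not part of the diff): a sketch of how the two new manager methods might be driven. It assumes the class in `py_pglite/manager.py` is exposed as `PGliteManager` with the usual `start()`/`stop()`/`is_running()` lifecycle; only `restart()` and `get_psycopg_uri()` are actually shown in this diff.

```python
from py_pglite import PGliteConfig
from py_pglite.manager import PGliteManager  # assumed class/module name

manager = PGliteManager(PGliteConfig())
manager.start()
try:
    # New in this PR: URI for direct psycopg use; raises RuntimeError if the
    # server has not been started yet.
    print(manager.get_psycopg_uri())

    # New in this PR: stop the running server (if any) and start a fresh one.
    manager.restart()
    assert manager.is_running()
finally:
    manager.stop()
```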
8 changes: 6 additions & 2 deletions py_pglite/utils.py
@@ -28,13 +28,13 @@
     return client.connect(connection_string)
 
 
-def test_connection(
+def check_connection(
     connection_string: str, client: DatabaseClient | None = None
 ) -> bool:
     """Test if database connection is working.
 
     Args:
-        connection_string: PostgreSQL connection string
+        connection_string: PostgreSQL connection string (DSN format preferred)
         client: Database client to use (defaults to auto-detected)
 
     Returns:
@@ -45,6 +45,10 @@
     return client.test_connection(connection_string)
 
 
+# Backward compatibility alias
+test_connection = check_connection
+
+
 def get_database_version(
     connection_string: str, client: DatabaseClient | None = None
 ) -> str | None:
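Note (not part of the diff): renaming the helper to `check_connection` appears intended to keep pytest from treating a module-level callable named `test_connection` as a test (see the `collect_ignore_glob = py_pglite/utils.py` line added to pytest.ini below), while the alias keeps old imports working. A small sketch follows; the DSN literal is made up and simply mimics the key/value form that `get_dsn()` is documented to return.

```python
from py_pglite.utils import check_connection, test_connection

# The backward-compatibility alias points at the same function object.
assert test_connection is check_connection

# Illustrative DSN only; in practice use PGliteConfig().get_dsn().
dsn = "host=/tmp/py-pglite-12345-abcdef12 dbname=postgres user=postgres"

if check_connection(dsn):
    print("PGlite is accepting connections")
else:
    print("PGlite is not ready yet")
```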
1 change: 1 addition & 0 deletions pyproject.toml
@@ -45,6 +45,7 @@ dev = [
     "pytest>=7.0.0",
     "pytest-asyncio>=0.21.0",
     "pytest-cov>=6.1.1",
+    "pytest-mock>=3.0.0",
     "mypy>=1.16.0",
     "ruff>=0.11.12",
     "build>=1.2.2.post1",
26 changes: 25 additions & 1 deletion pytest.ini
@@ -39,4 +39,28 @@ addopts = --tb=short -v --strict-markers
 #   pytest -m performance    # Performance benchmarks
 #   pytest -m integration    # Integration tests
 #
-# 🎨 Beautiful developer experience - just like Vite!
+# 🎨 Beautiful developer experience - just like Vite!
+
+[tool:pytest]
+minversion = 6.0
+addopts = -ra -q --strict-markers
+testpaths = tests
+python_files = test_*.py
+python_classes = Test*
+python_functions = test_*
+# Exclude utility functions from test discovery
+collect_ignore_glob = py_pglite/utils.py
+markers =
+    sqlalchemy: marks tests as requiring SQLAlchemy (deselect with '-m "not sqlalchemy"')
+    django: marks tests as requiring Django (deselect with '-m "not django"')
+    extensions: marks tests as requiring extension dependencies like pgvector
+    slow: marks tests as slow (deselect with '-m "not slow"')
+    integration: marks tests as integration tests
+    unit: marks tests as unit tests
+    stress: marks tests as stress/load tests
+filterwarnings =
+    ignore::DeprecationWarning
+    ignore::PendingDeprecationWarning
+    # Only suppress specific AsyncpgClient warnings during test execution
+    # These occur because AsyncpgClient uses sync-over-async pattern with mocking
+    ignore:coroutine 'AsyncpgClient\._async_execute_query' was never awaited:RuntimeWarning
8 changes: 3 additions & 5 deletions scripts/dev.py
@@ -89,12 +89,10 @@ def lint_check(self) -> bool:
         print("=" * 50)
 
         success = True
-        if not self.run_command(
-            "Ruff linting", self.ruff_cmd + ["check", "py_pglite/"]
-        ):
+        if not self.run_command("Ruff linting", self.ruff_cmd + ["check", "."]):
             success = False
         if not self.run_command(
-            "Ruff formatting", self.ruff_cmd + ["format", "--check", "py_pglite/"]
+            "Ruff formatting", self.ruff_cmd + ["format", "--check", "."]
         ):
             success = False
         if not self.run_command("MyPy type checking", self.mypy_cmd + ["py_pglite/"]):
@@ -128,7 +126,7 @@ def test_examples(self) -> bool:
                 "pytest",
                 "examples/test_basic.py",
                 "examples/test_fastapi_auth_example.py",
-                "examples/test_utils.py",
+                "examples/test_example_utils.py",
                 "-v",
             ],
         ):
15 changes: 9 additions & 6 deletions tests/test_advanced.py
@@ -1,9 +1,10 @@
 """Advanced example showing manual PGlite management and custom configuration."""
 
+from typing import TYPE_CHECKING
+
 import pytest
 from sqlalchemy import text
 from sqlmodel import Field, Session, SQLModel, select
-from typing import TYPE_CHECKING
 
 from py_pglite import PGliteConfig
 from py_pglite.sqlalchemy import SQLAlchemyPGliteManager
@@ -59,7 +60,8 @@ def test_manual_lifecycle_management():
     manager.start()
     assert manager.is_running()
 
-    # Get engine and use it (readiness is checked in fixture, no need to check again)
+    # Get engine and use it (readiness is checked in fixture, no need to check
+    # again)
     engine = manager.get_engine(echo=True)  # Enable SQL logging
     SQLModel.metadata.create_all(engine)
 
@@ -94,9 +96,9 @@ def test_manual_lifecycle_management():
         with session.connection() as conn:
             result = conn.execute(
                 text("""
-                SELECT p.name, o.quantity, o.total
-                FROM product p
-                JOIN "order" o ON p.id = o.product_id
+                    SELECT p.name, o.quantity, o.total
+                    FROM product p
+                    JOIN "order" o ON p.id = o.product_id
                     WHERE p.category = :category
                 """),
                 {"category": "Electronics"},
@@ -180,7 +182,8 @@ def test_multiple_sessions():
         all_products = final_session.exec(select(Product)).all()
         assert len(all_products) == 4  # 1 original + 3 new
         print(
-            f"All sessions completed successfully, total products: {len(all_products)}"
+            f"All sessions completed successfully, total products: "
+            f"{len(all_products)}"
         )
     finally:
         final_session.close()