Skip to content

Commit 48dc445

Browse files
author
Benjamin PILIA
committed
Refacto integration tests
1 parent a738b67 commit 48dc445

4 files changed

Lines changed: 125 additions & 160 deletions

File tree

api/infrastructure/fastapi/endpoints/admin/providers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,11 @@ async def create_provider(
8484
raise ProviderNotReachableHTTPException(name)
8585
case ProviderAlreadyExistsError(model_name, url, router_id):
8686
raise ProviderAlreadyExistsHTTPException(model_name, url, router_id)
87-
case InconsistentModelMaxContextLengthError(actual_max_context_length, expected_max_context_length, router_name):
87+
case InconsistentModelMaxContextLengthError(expected_max_context_length, actual_max_context_length, router_name):
8888
raise InconsistentModelMaxContextLengthHTTPException(
8989
input_max_context_length=actual_max_context_length, model_max_context_length=expected_max_context_length, model_name=router_name
9090
)
91-
case InconsistentModelVectorSizeError(actual_vector_size, expected_vector_size, router_name):
91+
case InconsistentModelVectorSizeError(expected_vector_size, actual_vector_size, router_name):
9292
raise InconsistentModelVectorSizeHTTPException(actual_vector_size, expected_vector_size, router_name)
9393
case RouterNotFoundError(router_id):
9494
raise RouterNotFoundHTTPException(router_id)

api/infrastructure/postgres/_postgresproviderrepository.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,8 @@ async def create_provider(
5252
)
5353
.returning(ProviderTable)
5454
)
55-
async with self.postgres_session.begin_nested():
56-
result = await self.postgres_session.execute(query)
57-
row = result.scalar_one()
58-
55+
result = await self.postgres_session.execute(query)
56+
row = result.scalar_one()
5957
return Provider(
6058
router_id=row.router_id,
6159
user_id=row.user_id,

api/tests/integration/conftest.py

Lines changed: 63 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from httpx import ASGITransport, AsyncClient
55
import pytest
66
import pytest_asyncio
7-
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
7+
from sqlalchemy import event
8+
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
89
from sqlalchemy.pool import NullPool
910

1011
from api.app import create_app
@@ -67,11 +68,6 @@ async def test_engine():
6768
await engine.dispose()
6869

6970

70-
@pytest_asyncio.fixture(scope="session")
71-
async def test_session_factory(test_engine):
72-
return async_sessionmaker(test_engine, class_=AsyncSession, expire_on_commit=False)
73-
74-
7571
def _all_sql_factories():
7672
result = []
7773
stack = list(factories.BaseSQLFactory.__subclasses__())
@@ -82,47 +78,64 @@ def _all_sql_factories():
8278
return result
8379

8480

85-
@pytest_asyncio.fixture(scope="function")
86-
async def db_session(test_session_factory) -> AsyncGenerator[AsyncSession]:
87-
async with test_session_factory() as session:
88-
all_sql_factories = factories.BaseSQLFactory.__subclasses__()
89-
session.expire_on_commit = False
90-
try:
91-
async with session.begin_nested():
92-
for factory in all_sql_factories:
93-
factory._meta.sqlalchemy_session = session
94-
yield session
95-
finally:
96-
if session.in_transaction():
97-
await session.rollback()
98-
await session.close()
99-
81+
# @pytest_asyncio.fixture(scope="session")
82+
# async def test_session_factory(test_engine):
83+
# return async_sessionmaker(test_engine, class_=AsyncSession, expire_on_commit=False)
10084

101-
# @pytest_asyncio.fixture(scope="function")
102-
# async def db_session(test_engine) -> AsyncGenerator[AsyncSession]:
103-
# async with test_engine.connect() as connection:
104-
# transaction = await connection.begin()
105-
#
106-
# session = AsyncSession(bind=connection, expire_on_commit=False)
107-
# await session.begin_nested()
108-
#
109-
# all_sql_factories = _all_sql_factories()
110-
# for factory in all_sql_factories:
111-
# factory._meta.sqlalchemy_session = session
112-
#
113-
# @event.listens_for(session.sync_session, "after_transaction_end")
114-
# def restart_savepoint(sess, trans):
115-
# if trans.nested and not trans._parent.nested:
116-
# sess.begin_nested()
11785
#
86+
# @pytest_asyncio.fixture(scope="function")
87+
# async def db_session(test_session_factory) -> AsyncGenerator[AsyncSession]:
88+
# async with test_session_factory() as session:
89+
# all_sql_factories = factories.BaseSQLFactory.__subclasses__()
90+
# session.expire_on_commit = False
11891
# try:
119-
# yield session
92+
# async with session.begin_nested():
93+
# for factory in all_sql_factories:
94+
# factory._meta.sqlalchemy_session = session
95+
# yield session
12096
# finally:
121-
# event.remove(session.sync_session, "after_transaction_end", restart_savepoint)
122-
# for factory in all_sql_factories:
123-
# factory._meta.sqlalchemy_session = None
97+
# if session.in_transaction():
98+
# await session.rollback()
12499
# await session.close()
125-
# await transaction.rollback()
100+
101+
102+
def pytest_addoption(parser):
103+
parser.addoption(
104+
"--commit-db",
105+
action="store_true",
106+
default=False,
107+
help="Commit DB changes after each test (for debugging with psql).",
108+
)
109+
110+
111+
@pytest_asyncio.fixture(scope="function")
112+
async def db_session(test_engine, request) -> AsyncGenerator[AsyncSession]:
113+
async with test_engine.connect() as connection:
114+
transaction = await connection.begin()
115+
116+
session = AsyncSession(bind=connection, expire_on_commit=False)
117+
await session.begin_nested()
118+
119+
all_sql_factories = _all_sql_factories()
120+
for factory in all_sql_factories:
121+
factory._meta.sqlalchemy_session = session
122+
123+
@event.listens_for(session.sync_session, "after_transaction_end")
124+
def restart_savepoint(sess, trans):
125+
if trans.nested and not trans._parent.nested:
126+
sess.begin_nested()
127+
128+
try:
129+
yield session
130+
finally:
131+
event.remove(session.sync_session, "after_transaction_end", restart_savepoint)
132+
for factory in all_sql_factories:
133+
factory._meta.sqlalchemy_session = None
134+
await session.close()
135+
if request.config.getoption("--commit-db"):
136+
await transaction.commit()
137+
else:
138+
await transaction.rollback()
126139

127140

128141
@pytest_asyncio.fixture(scope="session")
@@ -137,7 +150,7 @@ def model_registry():
137150

138151

139152
@pytest_asyncio.fixture(scope="function")
140-
async def client(db_session, model_registry, test_configuration) -> AsyncGenerator[AsyncClient, None]:
153+
async def app(db_session, model_registry, test_configuration):
141154
app = create_app(test_configuration, skip_lifespan=True)
142155

143156
async def override_get_postgres_session():
@@ -155,7 +168,12 @@ async def override_get_postgres_session():
155168
app.dependency_overrides[get_model_registry] = lambda: model_registry
156169

157170
try:
158-
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac:
159-
yield ac
171+
yield app
160172
finally:
161173
app.dependency_overrides.clear()
174+
175+
176+
@pytest_asyncio.fixture(scope="function")
177+
async def client(app) -> AsyncGenerator[AsyncClient, None]:
178+
async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as ac:
179+
yield ac
Lines changed: 58 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,26 @@
1+
from unittest.mock import AsyncMock
2+
13
import httpx
24
from httpx import AsyncClient
35
import pytest
46
import pytest_asyncio
57
import respx
68

9+
from api.dependencies import create_provider_use_case_factory
10+
from api.domain.model.errors import InconsistentModelMaxContextLengthError, InconsistentModelVectorSizeError
11+
from api.domain.provider.errors import InvalidProviderTypeError, ProviderAlreadyExistsError, ProviderNotReachableError
12+
from api.domain.router.errors import RouterNotFoundError
713
from api.schemas.models import ModelType
814
from api.tests.helpers import create_token
9-
from api.tests.integration.factories import ProviderSQLFactory, RouterSQLFactory, UserSQLFactory
15+
from api.tests.integration.factories import RouterSQLFactory, UserSQLFactory
1016
from api.utils.variables import EndpointRoute
1117

1218
URL = f"/v1{EndpointRoute.ADMIN_PROVIDERS}"
1319

1420
DEFAULT_PROVIDER_URL = "http://my-test-provider/"
1521

1622

17-
def _valid_body(router_id=1, **overrides) -> dict:
23+
def _valid_body(router_id: int, **overrides) -> dict:
1824
"""Return a minimal valid provider creation body, with optional overrides."""
1925
body = {
2026
"router": router_id,
@@ -48,20 +54,18 @@ def _mock_provider_reachable(respx_mock, base_url=DEFAULT_PROVIDER_URL, max_cont
4854
)
4955

5056

51-
def _mock_provider_unreachable(respx_mock, base_url=DEFAULT_PROVIDER_URL):
52-
"""Mock a provider that cannot be reached."""
53-
base_url = base_url.rstrip("/")
54-
respx_mock.get(f"{base_url}/v1/models").mock(side_effect=httpx.ConnectError("connection refused"))
55-
respx_mock.post(f"{base_url}/v1/embeddings").mock(side_effect=httpx.ConnectError("connection refused"))
56-
57-
5857
@pytest.mark.asyncio(loop_scope="session")
5958
class TestCreateProvider:
6059
@pytest_asyncio.fixture(autouse=True)
6160
async def setup(self, db_session):
6261
self.admin_user = UserSQLFactory(admin_user=True)
6362
self.token = await create_token(db_session, name="admin_token", user=self.admin_user)
6463

64+
@pytest_asyncio.fixture(autouse=True)
65+
async def cleanup_overrides(self, app):
66+
yield
67+
app.dependency_overrides.pop(create_provider_use_case_factory, None)
68+
6569
@respx.mock
6670
async def test_happy_path(self, client: AsyncClient, db_session):
6771
router = RouterSQLFactory(user=self.admin_user, type=ModelType.TEXT_GENERATION)
@@ -76,111 +80,56 @@ async def test_happy_path(self, client: AsyncClient, db_session):
7680
assert response.status_code == 201, response.text
7781
assert isinstance(response.json()["id"], int)
7882

79-
@respx.mock
80-
async def test_incompatible_provider_type(self, client: AsyncClient, db_session):
81-
router = RouterSQLFactory(user=self.admin_user, type=ModelType.TEXT_GENERATION)
82-
await db_session.flush()
83-
_mock_provider_reachable(respx, base_url="https://tei.example.com")
84-
85-
response = await client.post(
86-
url=URL,
87-
headers={"Authorization": f"Bearer {self.token.token}"},
88-
json=_valid_body(router.id, type="tei", url="https://tei.example.com/"),
89-
)
90-
91-
assert response.status_code == 400
92-
assert response.json().get("detail") == "Invalid model provider type tei for text-generation router."
93-
94-
@respx.mock
95-
async def test_provider_not_reachable(self, client: AsyncClient, db_session):
96-
router = RouterSQLFactory(user=self.admin_user, type=ModelType.TEXT_GENERATION)
97-
await db_session.flush()
98-
_mock_provider_unreachable(respx)
99-
100-
response = await client.post(
101-
url=URL,
102-
headers={"Authorization": f"Bearer {self.token.token}"},
103-
json=_valid_body(router.id),
104-
)
105-
106-
assert response.status_code == 424
107-
assert response.json().get("detail") == "Model provider my-model not reachable."
108-
109-
@respx.mock
110-
async def test_provider_already_exists(self, client: AsyncClient, db_session):
111-
router = RouterSQLFactory(user=self.admin_user, type=ModelType.TEXT_GENERATION)
112-
ProviderSQLFactory(
113-
router=router,
114-
user=self.admin_user,
115-
url=DEFAULT_PROVIDER_URL,
116-
model_name="my-model",
117-
max_context_length=4096,
118-
vector_size=None,
119-
)
120-
await db_session.flush()
121-
_mock_provider_reachable(respx)
122-
123-
response = await client.post(
124-
url=URL,
125-
headers={"Authorization": f"Bearer {self.token.token}"},
126-
json=_valid_body(router.id),
127-
)
128-
assert response.status_code == 409
129-
assert response.json().get("detail") == "Model provider my-model for url http://my-test-provider/ already exists for router 4."
130-
131-
@respx.mock
132-
async def test_provider_mismatch_max_context_length(self, client: AsyncClient, db_session):
133-
router = RouterSQLFactory(user=self.admin_user, type=ModelType.TEXT_EMBEDDINGS_INFERENCE, name="test_router")
134-
ProviderSQLFactory(
135-
router=router,
136-
user=self.admin_user,
137-
url="https://albert.api.etalab.gouv.fr/",
138-
model_name="my-model",
139-
max_context_length=4096,
140-
vector_size=1234,
141-
)
142-
await db_session.flush()
143-
_mock_provider_reachable(respx, max_context_length=1234, vector_size=1234)
144-
145-
response = await client.post(
146-
url=URL,
147-
headers={"Authorization": f"Bearer {self.token.token}"},
148-
json=_valid_body(router.id),
149-
)
150-
151-
assert response.status_code == 403
152-
assert response.json().get("detail") == "Inconsistent max context length for test_router. Expected: 1234. Actual: 4096"
153-
154-
@respx.mock
155-
async def test_provider_mismatch_vector_size(self, client: AsyncClient, db_session):
156-
router = RouterSQLFactory(user=self.admin_user, type=ModelType.TEXT_GENERATION, name="test_router")
157-
ProviderSQLFactory(
158-
router=router,
159-
user=self.admin_user,
160-
url="https://albert.api.etalab.gouv.fr/",
161-
model_name="my-model",
162-
max_context_length=4096,
163-
vector_size=1234,
164-
)
165-
await db_session.flush()
166-
_mock_provider_reachable(respx, max_context_length=1234, vector_size=1234)
83+
@pytest.mark.parametrize(
84+
"use_case_result,expected_status,expected_detail",
85+
[
86+
(RouterNotFoundError(router_id=1), 404, "Model router 1 not found."),
87+
(
88+
InvalidProviderTypeError(provider_type="tei", router_type="text-generation"),
89+
400,
90+
"Invalid model provider type tei for text-generation router.",
91+
),
92+
(ProviderNotReachableError(model_name="my-model"), 424, "Model provider my-model not reachable."),
93+
(
94+
ProviderAlreadyExistsError(model_name="my-model", url=DEFAULT_PROVIDER_URL, router_id=1),
95+
409,
96+
f"Model provider my-model for url {DEFAULT_PROVIDER_URL} already exists for router 1.",
97+
),
98+
(
99+
InconsistentModelMaxContextLengthError(expected_max_context_length=4096, actual_max_context_length=2048, router_name="my-router"),
100+
403,
101+
"Inconsistent max context length for my-router. Expected: 4096. Actual: 2048",
102+
),
103+
(
104+
InconsistentModelVectorSizeError(expected_vector_size=768, actual_vector_size=384, router_name="my-router"),
105+
403,
106+
"Inconsistent vector size for my-router. Expected: 768. Actual: 384",
107+
),
108+
],
109+
)
110+
async def test_error_maps_to_correct_http_status(self, client: AsyncClient, app, use_case_result, expected_status, expected_detail):
111+
mock_use_case = AsyncMock()
112+
mock_use_case.execute.return_value = use_case_result
113+
app.dependency_overrides[create_provider_use_case_factory] = lambda: mock_use_case
167114

168115
response = await client.post(
169116
url=URL,
170117
headers={"Authorization": f"Bearer {self.token.token}"},
171-
json=_valid_body(router.id),
118+
json=_valid_body(router_id=1),
172119
)
173120

174-
assert response.status_code == 403
175-
assert response.json().get("detail") == "Inconsistent vector size for test_router. Expected: None. Actual: 1234"
121+
assert response.status_code == expected_status
122+
assert response.json().get("detail") == expected_detail
176123

177-
@respx.mock
178-
async def test_router_not_found(self, client: AsyncClient, db_session):
179-
response = await client.post(
180-
url=URL,
181-
headers={"Authorization": f"Bearer {self.token.token}"},
182-
json=_valid_body(999999),
183-
)
124+
@pytest.mark.parametrize(
125+
"headers,expected_status,expected_detail",
126+
[
127+
({}, 401, "Not authenticated"),
128+
({"Authorization": "Bearer invalid-token"}, 403, "Invalid API key."),
129+
],
130+
)
131+
async def test_auth(self, client: AsyncClient, headers, expected_status, expected_detail):
132+
response = await client.post(url=URL, headers=headers, json=_valid_body(router_id=1))
184133

185-
assert response.status_code == 404
186-
assert response.json().get("detail") == "Model router 999999 not found."
134+
assert response.status_code == expected_status
135+
assert response.json().get("detail") == expected_detail

0 commit comments

Comments
 (0)