1+ from unittest .mock import AsyncMock
2+
13import httpx
24from httpx import AsyncClient
35import pytest
46import pytest_asyncio
57import respx
68
9+ from api .dependencies import create_provider_use_case_factory
10+ from api .domain .model .errors import InconsistentModelMaxContextLengthError , InconsistentModelVectorSizeError
11+ from api .domain .provider .errors import InvalidProviderTypeError , ProviderAlreadyExistsError , ProviderNotReachableError
12+ from api .domain .router .errors import RouterNotFoundError
713from api .schemas .models import ModelType
814from api .tests .helpers import create_token
9- from api .tests .integration .factories import ProviderSQLFactory , RouterSQLFactory , UserSQLFactory
15+ from api .tests .integration .factories import RouterSQLFactory , UserSQLFactory
1016from api .utils .variables import EndpointRoute
1117
1218URL = f"/v1{ EndpointRoute .ADMIN_PROVIDERS } "
1319
1420DEFAULT_PROVIDER_URL = "http://my-test-provider/"
1521
1622
17- def _valid_body (router_id = 1 , ** overrides ) -> dict :
23+ def _valid_body (router_id : int , ** overrides ) -> dict :
1824 """Return a minimal valid provider creation body, with optional overrides."""
1925 body = {
2026 "router" : router_id ,
@@ -48,20 +54,18 @@ def _mock_provider_reachable(respx_mock, base_url=DEFAULT_PROVIDER_URL, max_cont
4854 )
4955
5056
51- def _mock_provider_unreachable (respx_mock , base_url = DEFAULT_PROVIDER_URL ):
52- """Mock a provider that cannot be reached."""
53- base_url = base_url .rstrip ("/" )
54- respx_mock .get (f"{ base_url } /v1/models" ).mock (side_effect = httpx .ConnectError ("connection refused" ))
55- respx_mock .post (f"{ base_url } /v1/embeddings" ).mock (side_effect = httpx .ConnectError ("connection refused" ))
56-
57-
5857@pytest .mark .asyncio (loop_scope = "session" )
5958class TestCreateProvider :
6059 @pytest_asyncio .fixture (autouse = True )
6160 async def setup (self , db_session ):
6261 self .admin_user = UserSQLFactory (admin_user = True )
6362 self .token = await create_token (db_session , name = "admin_token" , user = self .admin_user )
6463
64+ @pytest_asyncio .fixture (autouse = True )
65+ async def cleanup_overrides (self , app ):
66+ yield
67+ app .dependency_overrides .pop (create_provider_use_case_factory , None )
68+
6569 @respx .mock
6670 async def test_happy_path (self , client : AsyncClient , db_session ):
6771 router = RouterSQLFactory (user = self .admin_user , type = ModelType .TEXT_GENERATION )
@@ -76,111 +80,56 @@ async def test_happy_path(self, client: AsyncClient, db_session):
7680 assert response .status_code == 201 , response .text
7781 assert isinstance (response .json ()["id" ], int )
7882
79- @respx .mock
80- async def test_incompatible_provider_type (self , client : AsyncClient , db_session ):
81- router = RouterSQLFactory (user = self .admin_user , type = ModelType .TEXT_GENERATION )
82- await db_session .flush ()
83- _mock_provider_reachable (respx , base_url = "https://tei.example.com" )
84-
85- response = await client .post (
86- url = URL ,
87- headers = {"Authorization" : f"Bearer { self .token .token } " },
88- json = _valid_body (router .id , type = "tei" , url = "https://tei.example.com/" ),
89- )
90-
91- assert response .status_code == 400
92- assert response .json ().get ("detail" ) == "Invalid model provider type tei for text-generation router."
93-
94- @respx .mock
95- async def test_provider_not_reachable (self , client : AsyncClient , db_session ):
96- router = RouterSQLFactory (user = self .admin_user , type = ModelType .TEXT_GENERATION )
97- await db_session .flush ()
98- _mock_provider_unreachable (respx )
99-
100- response = await client .post (
101- url = URL ,
102- headers = {"Authorization" : f"Bearer { self .token .token } " },
103- json = _valid_body (router .id ),
104- )
105-
106- assert response .status_code == 424
107- assert response .json ().get ("detail" ) == "Model provider my-model not reachable."
108-
109- @respx .mock
110- async def test_provider_already_exists (self , client : AsyncClient , db_session ):
111- router = RouterSQLFactory (user = self .admin_user , type = ModelType .TEXT_GENERATION )
112- ProviderSQLFactory (
113- router = router ,
114- user = self .admin_user ,
115- url = DEFAULT_PROVIDER_URL ,
116- model_name = "my-model" ,
117- max_context_length = 4096 ,
118- vector_size = None ,
119- )
120- await db_session .flush ()
121- _mock_provider_reachable (respx )
122-
123- response = await client .post (
124- url = URL ,
125- headers = {"Authorization" : f"Bearer { self .token .token } " },
126- json = _valid_body (router .id ),
127- )
128- assert response .status_code == 409
129- assert response .json ().get ("detail" ) == "Model provider my-model for url http://my-test-provider/ already exists for router 4."
130-
131- @respx .mock
132- async def test_provider_mismatch_max_context_length (self , client : AsyncClient , db_session ):
133- router = RouterSQLFactory (user = self .admin_user , type = ModelType .TEXT_EMBEDDINGS_INFERENCE , name = "test_router" )
134- ProviderSQLFactory (
135- router = router ,
136- user = self .admin_user ,
137- url = "https://albert.api.etalab.gouv.fr/" ,
138- model_name = "my-model" ,
139- max_context_length = 4096 ,
140- vector_size = 1234 ,
141- )
142- await db_session .flush ()
143- _mock_provider_reachable (respx , max_context_length = 1234 , vector_size = 1234 )
144-
145- response = await client .post (
146- url = URL ,
147- headers = {"Authorization" : f"Bearer { self .token .token } " },
148- json = _valid_body (router .id ),
149- )
150-
151- assert response .status_code == 403
152- assert response .json ().get ("detail" ) == "Inconsistent max context length for test_router. Expected: 1234. Actual: 4096"
153-
154- @respx .mock
155- async def test_provider_mismatch_vector_size (self , client : AsyncClient , db_session ):
156- router = RouterSQLFactory (user = self .admin_user , type = ModelType .TEXT_GENERATION , name = "test_router" )
157- ProviderSQLFactory (
158- router = router ,
159- user = self .admin_user ,
160- url = "https://albert.api.etalab.gouv.fr/" ,
161- model_name = "my-model" ,
162- max_context_length = 4096 ,
163- vector_size = 1234 ,
164- )
165- await db_session .flush ()
166- _mock_provider_reachable (respx , max_context_length = 1234 , vector_size = 1234 )
83+ @pytest .mark .parametrize (
84+ "use_case_result,expected_status,expected_detail" ,
85+ [
86+ (RouterNotFoundError (router_id = 1 ), 404 , "Model router 1 not found." ),
87+ (
88+ InvalidProviderTypeError (provider_type = "tei" , router_type = "text-generation" ),
89+ 400 ,
90+ "Invalid model provider type tei for text-generation router." ,
91+ ),
92+ (ProviderNotReachableError (model_name = "my-model" ), 424 , "Model provider my-model not reachable." ),
93+ (
94+ ProviderAlreadyExistsError (model_name = "my-model" , url = DEFAULT_PROVIDER_URL , router_id = 1 ),
95+ 409 ,
96+ f"Model provider my-model for url { DEFAULT_PROVIDER_URL } already exists for router 1." ,
97+ ),
98+ (
99+ InconsistentModelMaxContextLengthError (expected_max_context_length = 4096 , actual_max_context_length = 2048 , router_name = "my-router" ),
100+ 403 ,
101+ "Inconsistent max context length for my-router. Expected: 4096. Actual: 2048" ,
102+ ),
103+ (
104+ InconsistentModelVectorSizeError (expected_vector_size = 768 , actual_vector_size = 384 , router_name = "my-router" ),
105+ 403 ,
106+ "Inconsistent vector size for my-router. Expected: 768. Actual: 384" ,
107+ ),
108+ ],
109+ )
110+ async def test_error_maps_to_correct_http_status (self , client : AsyncClient , app , use_case_result , expected_status , expected_detail ):
111+ mock_use_case = AsyncMock ()
112+ mock_use_case .execute .return_value = use_case_result
113+ app .dependency_overrides [create_provider_use_case_factory ] = lambda : mock_use_case
167114
168115 response = await client .post (
169116 url = URL ,
170117 headers = {"Authorization" : f"Bearer { self .token .token } " },
171- json = _valid_body (router . id ),
118+ json = _valid_body (router_id = 1 ),
172119 )
173120
174- assert response .status_code == 403
175- assert response .json ().get ("detail" ) == "Inconsistent vector size for test_router. Expected: None. Actual: 1234"
121+ assert response .status_code == expected_status
122+ assert response .json ().get ("detail" ) == expected_detail
176123
177- @respx .mock
178- async def test_router_not_found (self , client : AsyncClient , db_session ):
179- response = await client .post (
180- url = URL ,
181- headers = {"Authorization" : f"Bearer { self .token .token } " },
182- json = _valid_body (999999 ),
183- )
124+ @pytest .mark .parametrize (
125+ "headers,expected_status,expected_detail" ,
126+ [
127+ ({}, 401 , "Not authenticated" ),
128+ ({"Authorization" : "Bearer invalid-token" }, 403 , "Invalid API key." ),
129+ ],
130+ )
131+ async def test_auth (self , client : AsyncClient , headers , expected_status , expected_detail ):
132+ response = await client .post (url = URL , headers = headers , json = _valid_body (router_id = 1 ))
184133
185- assert response .status_code == 404
186- assert response .json ().get ("detail" ) == "Model router 999999 not found."
134+ assert response .status_code == expected_status
135+ assert response .json ().get ("detail" ) == expected_detail
0 commit comments