Skip to content

Commit 1da5c3b

Browse files
author
Tibo Pendino
committed
feat(vector_store): add lifespan management for vector store without configuration
- update RouterTable
- add SQLAlchemy migration
- add stronger typing
- add default management with type constraint
- add is_default property for create and update routers
- update tests (not tested yet)
1 parent 96ba8a0 commit 1da5c3b

11 files changed

Lines changed: 129 additions & 47 deletions

File tree

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""add is_default property and partial unique constraint on is_default property for models
2+
3+
Revision ID: 4ad3859d4c64
4+
Revises: f02a2525b97c
5+
Create Date: 2026-02-02 22:04:55.231167
6+
7+
"""
8+
from typing import Sequence, Union
9+
10+
from alembic import op
11+
import sqlalchemy as sa
12+
import logging
13+
14+
15+
# revision identifiers, used by Alembic.
16+
revision: str = '4ad3859d4c64'
17+
down_revision: Union[str, None] = 'f02a2525b97c'
18+
branch_labels: Union[str, Sequence[str], None] = None
19+
depends_on: Union[str, Sequence[str], None] = None
20+
logger = logging.getLogger(__name__)
21+
22+
23+
def upgrade() -> None:
    """Apply this revision: add ``router.is_default`` and its partial unique index.

    Steps, in order:
      1. Add the non-nullable boolean column ``is_default`` to ``router`` with a
         ``FALSE`` server default.
      2. Create the unique index ``unique_default_per_model_type`` on
         ``router(type)``, limited through ``postgresql_where`` to rows where
         ``is_default IS TRUE``, so at most one such row can exist per type.

    A WARNING-level log line is emitted before each step and once at the end.
    """
    # Step 1: the new flag column.
    logger.warning("Upgrade: adding 'is_default' column to 'router' with server_default=FALSE for existing rows")
    is_default_column = sa.Column(
        'is_default',
        sa.Boolean(),
        nullable=False,
        server_default=sa.text('FALSE'),
    )
    op.add_column('router', is_default_column)

    # Step 2: uniqueness of "type" enforced only among rows flagged as default.
    logger.warning("Upgrade: creating partial unique index 'unique_default_per_model_type' on 'router(type)' where is_default IS TRUE")
    only_default_rows = sa.text('is_default IS TRUE')
    op.create_index(
        'unique_default_per_model_type',
        'router',
        ['type'],
        unique=True,
        postgresql_where=only_default_rows,
    )

    logger.warning("Upgrade: finished")
41+
42+
43+
def downgrade() -> None:
    """Revert this revision: drop the partial unique index, then the column.

    The index ``unique_default_per_model_type`` is dropped before the
    ``is_default`` column it filters on, i.e. the reverse of the order used
    by ``upgrade``. A WARNING-level log line is emitted before each step and
    once at the end.
    """
    # Step 1: remove the index first; its predicate refers to is_default.
    logger.warning("Downgrade: dropping partial unique index 'unique_default_per_model_type'")
    only_default_rows = sa.text('is_default IS TRUE')
    op.drop_index(
        'unique_default_per_model_type',
        table_name='router',
        postgresql_where=only_default_rows,
    )

    # Step 2: remove the flag column itself.
    logger.warning("Downgrade: dropping column 'is_default' from 'router'")
    op.drop_column('router', 'is_default')

    logger.warning("Downgrade: finished")

api/endpoints/admin/routers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ async def create_router(
3030
type=body.type,
3131
aliases=body.aliases,
3232
load_balancing_strategy=body.load_balancing_strategy,
33+
is_default=body.is_default,
3334
cost_prompt_tokens=body.cost_prompt_tokens,
3435
cost_completion_tokens=body.cost_completion_tokens,
3536
user_id=request_context.get().user_info.id,
@@ -78,6 +79,7 @@ async def update_router(
7879
type=body.type,
7980
aliases=body.aliases,
8081
load_balancing_strategy=body.load_balancing_strategy,
82+
is_default=body.is_default,
8183
cost_prompt_tokens=body.cost_prompt_tokens,
8284
cost_completion_tokens=body.cost_completion_tokens,
8385
postgres_session=postgres_session,

api/endpoints/ocr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ async def ocr_beta(
8484
"""
8585
Extracts text from PDF files using OCR.
8686
"""
87-
# check if file is a pdf (raises UnsupportedFileTypeException if not a PDF)
87+
# check if file is a PDF (raises UnsupportedFileTypeException if not a PDF)
8888
global_context.document_manager.parser_manager._detect_file_type(file=file, type=FileType.PDF)
8989

9090
# check file size

api/helpers/models/_modelregistry.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ async def setup(self, models: list[ModelConfiguration], postgres_session: AsyncS
122122
type=model.type,
123123
aliases=model.aliases,
124124
load_balancing_strategy=model.load_balancing_strategy,
125+
is_default=False,
125126
cost_prompt_tokens=model.cost_prompt_tokens,
126127
cost_completion_tokens=model.cost_completion_tokens,
127128
user_id=0, # setup as master user
@@ -181,6 +182,7 @@ async def create_router(
181182
type: ModelType,
182183
aliases: list[str],
183184
load_balancing_strategy: RouterLoadBalancingStrategy,
185+
is_default: bool,
184186
cost_prompt_tokens: float,
185187
cost_completion_tokens: float,
186188
user_id: int,
@@ -194,6 +196,7 @@ async def create_router(
194196
type(ModelType): The type of model
195197
aliases(List[str]): List of aliases for the model
196198
load_balancing_strategy(RouterLoadBalancingStrategy): The routing strategy to use
199+
is_default(bool): Whether the router is default for its type
197200
cost_prompt_tokens(float): The cost of a million prompt tokens
198201
cost_completion_tokens(float): The cost of a million completion tokens
199202
user_id(int): The user ID of owner of the router
@@ -213,6 +216,7 @@ async def create_router(
213216
name=name,
214217
type=type.value,
215218
load_balancing_strategy=load_balancing_strategy.value,
219+
is_default=is_default,
216220
cost_prompt_tokens=cost_prompt_tokens,
217221
cost_completion_tokens=cost_completion_tokens,
218222
)
@@ -276,6 +280,7 @@ async def update_router(
276280
type: ModelType | None,
277281
aliases: list[str] | None,
278282
load_balancing_strategy: RouterLoadBalancingStrategy | None,
283+
is_default: bool | None,
279284
cost_prompt_tokens: float | None,
280285
cost_completion_tokens: float | None,
281286
postgres_session: AsyncSession,
@@ -289,6 +294,7 @@ async def update_router(
289294
type(Optional[ModelType]): Optional new type
290295
aliases(Optional[List[str]]): Optional new aliases list (replaces existing)
291296
load_balancing_strategy(Optional[RouterLoadBalancingStrategy]): Optional new routing strategy
297+
is_default(Optional[bool]): Optional new is_default flag (one True per type)
292298
cost_prompt_tokens(Optional[float]): Optional new cost of a million prompt tokens
293299
cost_completion_tokens(Optional[float]): Optional new cost of a million completion tokens
294300
postgres_session(AsyncSession): Database postgres_session
@@ -311,6 +317,8 @@ async def update_router(
311317
update_values["type"] = type.value
312318
if load_balancing_strategy is not None:
313319
update_values["load_balancing_strategy"] = load_balancing_strategy.value
320+
if is_default is not None:
321+
update_values["is_default"] = is_default
314322
if name is not None:
315323
update_values["name"] = name
316324
if cost_prompt_tokens is not None:
@@ -330,6 +338,7 @@ async def update_router(
330338
await postgres_session.execute(query)
331339

332340
await postgres_session.commit()
341+
# TODO: Make the update method return the updated router
333342

334343
@staticmethod
335344
async def get_routers(
@@ -379,6 +388,7 @@ async def get_routers(
379388
RouterTable.user_id,
380389
RouterTable.type,
381390
RouterTable.load_balancing_strategy,
391+
RouterTable.is_default,
382392
RouterTable.cost_prompt_tokens,
383393
RouterTable.cost_completion_tokens,
384394
first_provider_subquery.c.max_context_length,
@@ -429,6 +439,7 @@ async def get_routers(
429439
type=ModelType(row["type"]),
430440
aliases=aliases.get(row["id"], []),
431441
load_balancing_strategy=RouterLoadBalancingStrategy(row["load_balancing_strategy"]),
442+
is_default=row["is_default"],
432443
vector_size=row["vector_size"],
433444
max_context_length=row["max_context_length"],
434445
cost_prompt_tokens=row["cost_prompt_tokens"] or 0.0,

api/schemas/admin/routers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class CreateRouter(BaseModel):
1717
type: ModelType = Field(..., description="Type of the model router. It will be used to identify the model router type.", examples=["text-generation"]) # fmt: off
1818
aliases: list[constr(strip_whitespace=True, min_length=1, max_length=64)] = Field(default_factory=list, description="Aliases of the model. It will be used to identify the model by users.", examples=[["model-alias", "model-alias-2"]]) # fmt: off
1919
load_balancing_strategy: RouterLoadBalancingStrategy = Field(default=RouterLoadBalancingStrategy.SHUFFLE, description="Routing strategy for load balancing between providers of the model. It will be used to identify the model type.", examples=["least_busy"]) # fmt: off
20+
is_default: bool = Field(default=False, description="Whether the router is the default one for its type.")
2021
cost_prompt_tokens: float = Field(default=0.0, ge=0.0, description="Cost of a million prompt tokens (decrease user budget)")
2122
cost_completion_tokens: float = Field(default=0.0, ge=0.0, description="Cost of a million completion tokens (decrease user budget)")
2223

@@ -30,6 +31,7 @@ class UpdateRouter(BaseModel):
3031
type: ModelType | None = Field(default=None, description="Type of the model router. It will be used to identify the model router type.", examples=["text-generation"]) # fmt: off
3132
aliases: list[constr(strip_whitespace=True, min_length=1, max_length=64)] | None = Field(default=None, description="Aliases of the model. It will be used to identify the model by users.", examples=[["model-alias", "model-alias-2"]]) # fmt: off
3233
load_balancing_strategy: RouterLoadBalancingStrategy | None = Field(default=None, description="Routing strategy for load balancing between providers of the model. It will be used to identify the model type.", examples=["least_busy"]) # fmt: off
34+
is_default: bool | None = Field(default=None, description="Whether the router is the default one for its type.")
3335
cost_prompt_tokens: float | None = Field(default=None, ge=0.0, description="Cost of a million prompt tokens (decrease user budget)")
3436
cost_completion_tokens: float | None = Field(default=None, ge=0.0, description="Cost of a million completion tokens (decrease user budget)")
3537

@@ -42,6 +44,7 @@ class Router(BaseModel):
4244
type: ModelType = Field(..., description="Type of the model router. It will be used to identify the model router type.", examples=["text-generation"]) # fmt: off
4345
aliases: list[str] | None = Field(default=None, description="Aliases of the model. It will be used to identify the model by users.", examples=[["model-alias", "model-alias-2"]]) # fmt: off
4446
load_balancing_strategy: RouterLoadBalancingStrategy = Field(..., description="Routing strategy for load balancing between providers of the model. It will be used to identify the model type.", examples=["least_busy"]) # fmt: off
47+
is_default: bool = Field(..., description="Whether the router is the default one for its type.")
4548
vector_size: int | None = Field(default=None, description="Dimension of the vectors, if the models are embeddings. Make sure it is the same for all models.") # fmt: off
4649
max_context_length: int | None = Field(default=None, description="Maximum amount of tokens a context could contains. Make sure it is the same for all models.") # fmt: off
4750
cost_prompt_tokens: float = Field(description="Cost of a million prompt tokens (decrease user budget)")

api/schemas/core/configuration.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -361,9 +361,6 @@ class Settings(ConfigBaseModel):
361361
monitoring_postgres_enabled: bool = Field(default=True, description="If true, the log usage will be written in the PostgreSQL database.") # fmt: off
362362
monitoring_prometheus_enabled: bool = Field(default=True, description="If true, Prometheus metrics will be exposed in the `/metrics` endpoint.") # fmt: off
363363

364-
# vector store
365-
vector_store_model: str | None = Field(default=None, description="Model used to vectorize the text in the vector store database. Is required if a vector store dependency is provided (Elasticsearch or Qdrant). This model must be defined in the `models` section and have type `text-embeddings-inference`.") # fmt: off
366-
367364
# postgres_session
368365
session_secret_key: str | None = Field(default=None, description='Secret key for postgres_session middleware. If not provided, the master key will be used.', examples=["knBnU1foGtBEwnOGTOmszldbSwSYLTcE6bdibC8bPGM"]) # fmt: off
369366

@@ -420,11 +417,6 @@ def validate_models(self) -> Any:
420417
if duplicated_models:
421418
raise ValueError(f"Duplicated model or alias names found: {", ".join(set(duplicated_models))}")
422419

423-
# check for interdependencies
424-
if self.dependencies.vector_store and self.settings.vector_store_model:
425-
assert self.settings.vector_store_model in models["all"], "Vector store model must be defined in models section."
426-
assert self.settings.vector_store_model in models[ModelType.TEXT_EMBEDDINGS_INFERENCE.value], f"The vector store model must have type {ModelType.TEXT_EMBEDDINGS_INFERENCE}." # fmt: off
427-
428420
return self
429421

430422

api/schemas/search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class SearchMethod(str, Enum):
1919

2020
class SearchArgs(BaseModel):
2121
collections: list[int] = Field(min_items=1, description="List of collections ID")
22-
rff_k: int = Field(default=20, description="k constant in RFF algorithm")
22+
rff_k: int = Field(default=20, description="k constant in RFF algorithm") # TO FIX: Does this allow zero or negative? Risk of invalid value IMO.
2323
k: int = Field(gt=0, le=200, default=10, deprecated=True, description="[DEPRECATED: use limit instead]Number of results to return")
2424
limit: int = Field(gt=0, le=200, default=10, description="Number of results to return")
2525
offset: int = Field(ge=0, default=0, description="Offset for pagination, specifying how many results to skip from the beginning")

api/sql/models.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from http import HTTPMethod
33
from typing import Optional
44

5-
from sqlalchemy import ForeignKey, UniqueConstraint, func
5+
from sqlalchemy import ForeignKey, Index, UniqueConstraint, func, text
66
from sqlalchemy.orm import Mapped, declarative_base, mapped_column, relationship
77

88
from api.schemas.admin.providers import ProviderCarbonFootprintZone, ProviderType
@@ -191,11 +191,21 @@ class Router(Base):
191191
name: Mapped[str] = mapped_column(unique=True)
192192
type: Mapped[ModelType]
193193
load_balancing_strategy: Mapped[RouterLoadBalancingStrategy]
194+
is_default: Mapped[bool] = mapped_column(default=False)
194195
cost_prompt_tokens: Mapped[float] = mapped_column(default=0.0)
195196
cost_completion_tokens: Mapped[float] = mapped_column(default=0.0)
196197
created: Mapped[dt.datetime] = mapped_column(insert_default=func.now())
197198
updated: Mapped[dt.datetime] = mapped_column(insert_default=func.now(), onupdate=func.now())
198199

200+
__table_args__ = (
201+
Index(
202+
"unique_default_per_model_type",
203+
"type",
204+
unique=True,
205+
postgresql_where=text("is_default IS TRUE"),
206+
),
207+
)
208+
199209
user: Mapped["User"] = relationship(back_populates="router")
200210
alias: Mapped[list["RouterAlias"]] = relationship(back_populates="router", cascade="all, delete-orphan", passive_deletes=True)
201211
provider: Mapped[list["Provider"]] = relationship(back_populates="router", cascade="all, delete-orphan", passive_deletes=True)

api/tests/unit/test_helpers/test_modelregistry/test_routers.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ async def test_create_router_success(postgres_session: AsyncSession, model_regis
7878
type=ModelType.TEXT_GENERATION,
7979
aliases=["alias1", "alias2"],
8080
load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE,
81+
is_default=False,
8182
cost_prompt_tokens=1.0,
8283
cost_completion_tokens=2.0,
8384
user_id=1,
@@ -104,6 +105,7 @@ async def test_create_router_master_user(postgres_session: AsyncSession, model_r
104105
type=ModelType.TEXT_GENERATION,
105106
aliases=[],
106107
load_balancing_strategy=RouterLoadBalancingStrategy.LEAST_BUSY,
108+
is_default=False,
107109
cost_prompt_tokens=0.0,
108110
cost_completion_tokens=0.0,
109111
user_id=0, # master user
@@ -124,6 +126,7 @@ async def test_create_router_already_exists(postgres_session: AsyncSession, mode
124126
type=ModelType.TEXT_GENERATION,
125127
aliases=[],
126128
load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE,
129+
is_default=False,
127130
cost_prompt_tokens=0.0,
128131
cost_completion_tokens=0.0,
129132
user_id=1,
@@ -146,6 +149,7 @@ async def test_create_router_alias_already_exists(postgres_session: AsyncSession
146149
type=ModelType.TEXT_GENERATION,
147150
aliases=["existing-alias"],
148151
load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE,
152+
is_default=False,
149153
cost_prompt_tokens=0.0,
150154
cost_completion_tokens=0.0,
151155
user_id=1,
@@ -198,6 +202,7 @@ async def test_update_router_success_all_fields(postgres_session: AsyncSession,
198202
type=ModelType.TEXT_GENERATION,
199203
aliases=["old-alias"],
200204
load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE,
205+
is_default=False,
201206
vector_size=None,
202207
max_context_length=4096,
203208
cost_prompt_tokens=1.0,
@@ -220,6 +225,7 @@ async def test_update_router_success_all_fields(postgres_session: AsyncSession,
220225
type=ModelType.TEXT_EMBEDDINGS_INFERENCE,
221226
aliases=["new-alias1", "new-alias2"],
222227
load_balancing_strategy=RouterLoadBalancingStrategy.LEAST_BUSY,
228+
is_default=None,
223229
cost_prompt_tokens=3.0,
224230
cost_completion_tokens=4.0,
225231
postgres_session=postgres_session,
@@ -237,6 +243,7 @@ async def test_update_router_alias_conflict(postgres_session: AsyncSession, mode
237243
type=ModelType.TEXT_GENERATION,
238244
aliases=[],
239245
load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE,
246+
is_default=False,
240247
vector_size=None,
241248
max_context_length=4096,
242249
cost_prompt_tokens=0.0,
@@ -260,6 +267,7 @@ async def test_update_router_alias_conflict(postgres_session: AsyncSession, mode
260267
type=None,
261268
aliases=["conflicting-alias"],
262269
load_balancing_strategy=None,
270+
is_default=None,
263271
cost_prompt_tokens=None,
264272
cost_completion_tokens=None,
265273
postgres_session=postgres_session,
@@ -275,6 +283,7 @@ async def test_update_router_noop(postgres_session: AsyncSession, model_registry
275283
type=ModelType.TEXT_GENERATION,
276284
aliases=[],
277285
load_balancing_strategy=RouterLoadBalancingStrategy.SHUFFLE,
286+
is_default=False,
278287
vector_size=None,
279288
max_context_length=4096,
280289
cost_prompt_tokens=0.0,
@@ -293,6 +302,7 @@ async def test_update_router_noop(postgres_session: AsyncSession, model_registry
293302
type=None,
294303
aliases=None,
295304
load_balancing_strategy=None,
305+
is_default=None,
296306
cost_prompt_tokens=None,
297307
cost_completion_tokens=None,
298308
postgres_session=postgres_session,

api/utils/dependencies.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1+
from collections.abc import AsyncGenerator, AsyncIterator
12
from contextvars import ContextVar
23

34
import redis.asyncio as redis
4-
from redis.asyncio import Redis as AsyncRedis
5-
from sqlalchemy.ext.asyncio import AsyncSession
65

76
from api.helpers._usagemanager import UsageManager
87
from api.helpers.models import ModelRegistry
@@ -32,7 +31,7 @@ def get_model_registry() -> ModelRegistry:
3231
return global_context.model_registry
3332

3433

35-
async def get_redis_client() -> AsyncRedis:
34+
async def get_redis_client() -> AsyncGenerator:
3635
"""
3736
Get a Redis client built from the shared connection pool.
3837
@@ -47,7 +46,7 @@ async def get_redis_client() -> AsyncRedis:
4746
await client.aclose()
4847

4948

50-
async def get_postgres_session() -> AsyncSession:
49+
async def get_postgres_session() -> AsyncIterator:
5150
"""
5251
Get a PostgreSQL postgres_session from the global context.
5352

0 commit comments

Comments
 (0)