Skip to content

Commit e1aeacb

Browse files
authored
chore(schemas): clean model schemas (provider and routers) (#783)
* chore(schemas): clean model schemas (provider and routers)
* chore(actions): exclude node_modules from secrets scan
1 parent 133ebac commit e1aeacb

26 files changed

Lines changed: 288 additions & 359 deletions

File tree

.github/.trufflehog-exclude.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
docs/node_modules/
2+
node_modules/

.github/workflows/secrets_scan.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ jobs:
2020
extra_args: |
2121
--results=verified,unknown
2222
--exclude-detectors=Postgres
23+
--exclude-paths=./.github/.trufflehog-exclude.txt
2324
- name: Install git-secrets
2425
run: |
2526
git clone https://github.com/awslabs/git-secrets.git

api/domain/model/entities.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
1-
from enum import Enum
1+
from enum import StrEnum
22

3-
from pydantic import BaseModel, Field
3+
from pydantic import BaseModel
44

55

6-
class Metric(str, Enum):
6+
class Metric(StrEnum):
77
TTFT = "ttft" # time to first token
88
LATENCY = "latency" # requests latency
99
INFLIGHT = "inflight" # requests concurrency
1010
PERFORMANCE = "performance" # custom performance metric
1111

1212

1313
class ModelCosts(BaseModel):
14-
prompt_tokens: float = Field(default=0.0, ge=0.0, description="Cost of a million prompt tokens (decrease user budget)")
15-
completion_tokens: float = Field(default=0.0, ge=0.0, description="Cost of a million completion tokens (decrease user budget)")
14+
prompt_tokens: float = 0.0
15+
completion_tokens: float = 0.0
1616

1717

18-
class ModelType(str, Enum):
18+
class ModelType(StrEnum):
1919
AUTOMATIC_SPEECH_RECOGNITION = "automatic-speech-recognition"
2020
IMAGE_TEXT_TO_TEXT = "image-text-to-text"
2121
IMAGE_TO_TEXT = "image-to-text"
@@ -25,10 +25,10 @@ class ModelType(str, Enum):
2525

2626

2727
class Model(BaseModel):
28-
id: str = Field(..., description="The model identifier, which can be referenced in the API endpoints.")
29-
type: ModelType = Field(..., description="The type of the model, which can be used to identify the model type.", examples=["text-generation"]) # fmt: off
30-
aliases: list[str] | None = Field(default=None, description="Aliases of the model. It will be used to identify the model by users.", examples=[["model-alias", "model-alias-2"]]) # fmt: off
31-
created: int = Field(..., description="Time of creation, as Unix timestamp.")
32-
owned_by: str = Field(..., description="The organization that owns the model.")
33-
max_context_length: int | None = Field(default=None, description="Maximum amount of tokens a context could contains. Makes sure it is the same for all models.") # fmt: off
34-
costs: ModelCosts = Field(..., description="Costs of the model.")
28+
id: str
29+
type: ModelType
30+
aliases: list[str] = []
31+
created: int
32+
owned_by: str
33+
max_context_length: int | None = None
34+
costs: ModelCosts

api/domain/provider/_providerrepository.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,11 @@ async def create_provider(
2828
pass
2929

3030
@abstractmethod
31-
async def delete_provider(
32-
self,
33-
provider_id: int,
34-
) -> Provider | None:
31+
async def delete_provider(self, provider_id: int) -> Provider | None:
3532
pass
3633

3734
@abstractmethod
38-
async def get_one_provider(
39-
self,
40-
provider_id: int,
41-
) -> Provider | None:
35+
async def get_one_provider(self, provider_id: int) -> Provider | None:
4236
pass
4337

4438
@abstractmethod

api/domain/provider/entities.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
from enum import Enum, StrEnum
2-
from typing import Literal
1+
from enum import StrEnum
32

43
import pycountry
5-
from pydantic import Field, constr
64

75
from api.domain import EntitiesPage
86
from api.domain.model.entities import ModelType
@@ -12,8 +10,7 @@
1210

1311
# Add world as a country code, default value of the carbon footprint computation framework
1412
country_codes = [country.alpha_3 for country in pycountry.countries] + ["WOR"]
15-
country_codes_dict = {str(code).upper(): str(code) for code in sorted(set(country_codes))}
16-
ProviderCarbonFootprintZone: type[Enum] = Enum("ProviderCarbonFootprintZone", country_codes_dict, type=str)
13+
ProviderCarbonFootprintZone = StrEnum("ProviderCarbonFootprintZone", {str(code).upper(): str(code) for code in sorted(set(country_codes))})
1714

1815

1916
class ProviderType(StrEnum):
@@ -72,22 +69,21 @@ class ProviderSortField(StrEnum):
7269

7370

7471
class Provider(BaseModel):
75-
object: Literal["provider"] = "provider"
76-
id: int = Field(..., description="Provider ID.") # fmt: off
77-
router_id: int = Field(..., description="ID of the router that owns the provider.") # fmt: off
78-
user_id: int = Field(..., description="ID of the user that owns the provider.") # fmt: off
79-
type: ProviderType = Field(..., description="Provider type.") # fmt: off
80-
url: constr(strip_whitespace=True, min_length=1, to_lower=True) | None = Field(default=None, description="Provider API url. The url must only contain the domain name (without `/v1` suffix for example).") # fmt: off
81-
key: str | None = Field(description="Provider API key.") # fmt: off
82-
timeout: int = Field(..., description="Timeout for the provider requests, after user receive an 500 error (model is too busy).") # fmt: off
83-
model_name: str = Field(..., description="Model name from the model provider.") # fmt: off
84-
model_hosting_zone: ProviderCarbonFootprintZone = Field(default=ProviderCarbonFootprintZone.WOR, description="Model hosting zone using ISO 3166-1 alpha-3 code format (e.g., `WOR` for World, `FRA` for France, `USA` for United States). This determines the electricity mix used for carbon intensity calculations. For more information, see https://ecologits.ai", examples=["WOR"]) # fmt: off
85-
model_total_params: int = Field(default=0, ge=0, description="Total params of the model in billions of parameters for carbon footprint computation. If not provided, the active params will be used if provided, else carbon footprint will not be computed. For more information, see https://ecologits.ai") # fmt: off
86-
model_active_params: int = Field(default=0, ge=0, description="Active params of the model in billions of parameters for carbon footprint computation. If not provided, the total params will be used if provided, else carbon footprint will not be computed. For more information, see https://ecologits.ai") # fmt: off
87-
qos_metric: Metric | None = Field(description="The metric to use for the QoS policy. If not provided, no QoS policy is applied.") # fmt: off
88-
qos_limit: float | None = Field(default=None, ge=0.0, description="The value to use for the quality of service. Depends of the metric, the value can be a percentile, a threshold, etc.") # fmt: off
89-
created: int | None = Field(default=None, description="Time of creation, as Unix timestamp.") # fmt: off
90-
updated: int | None = Field(default=None, description="Time of last update, as Unix timestamp.") # fmt: off
72+
id: int
73+
router_id: int
74+
user_id: int
75+
type: ProviderType
76+
url: str
77+
key: str | None = None
78+
timeout: int
79+
model_name: str
80+
model_hosting_zone: ProviderCarbonFootprintZone = ProviderCarbonFootprintZone.WOR
81+
model_total_params: int = 0
82+
model_active_params: int = 0
83+
qos_metric: Metric | None = None
84+
qos_limit: float | None = None
85+
created: int
86+
updated: int
9187

9288
def with_router_id(self, router_id: int) -> "Provider":
9389
return self.model_copy(update={"router_id": router_id})

api/domain/router/entities.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from enum import StrEnum
22

3-
from pydantic import BaseModel, Field
3+
from pydantic import BaseModel
44

55
from api.domain import EntitiesPage
66
from api.domain.model import ModelType as RouterType
@@ -15,19 +15,19 @@ class RouterLoadBalancingStrategy(StrEnum):
1515

1616

1717
class Router(BaseModel):
18-
id: int = Field(..., description="ID of the router.") # fmt: off
19-
name: str = Field(..., description="Name of the router.") # fmt: off
20-
user_id: int = Field(..., description="ID of the user that owns the router.") # fmt: off
21-
type: RouterType = Field(..., description="Type of the model router. It will be used to identify the model router type.", examples=["text-generation"]) # fmt: off
22-
aliases: list[str] | None = Field(default=None, description="Aliases of the model. It will be used to identify the model by users.", examples=[["model-alias", "model-alias-2"]]) # fmt: off
23-
load_balancing_strategy: RouterLoadBalancingStrategy = Field(..., description="Routing strategy for load balancing between providers of the model. It will be used to identify the model type.", examples=["least_busy"]) # fmt: off
24-
vector_size: int | None = Field(default=None, description="Dimension of the vectors, if the models are embeddings. Make sure it is the same for all models.") # fmt: off
25-
max_context_length: int | None = Field(default=None, description="Maximum amount of tokens a context could contains. Make sure it is the same for all models.") # fmt: off
26-
cost_prompt_tokens: float = Field(description="Cost of a million prompt tokens (decrease user budget)")
27-
cost_completion_tokens: float = Field(description="Cost of a million completion tokens (decrease user budget)")
28-
providers: int = Field(default=0, description="Number of providers in the router.") # fmt: off
29-
created: int = Field(..., description="Time of creation, as Unix timestamp.") # fmt: off
30-
updated: int = Field(..., description="Time of last update, as Unix timestamp.") # fmt: off
18+
id: int
19+
name: str
20+
user_id: int
21+
type: RouterType
22+
aliases: list[str] | None
23+
load_balancing_strategy: RouterLoadBalancingStrategy
24+
vector_size: int | None
25+
max_context_length: int | None
26+
cost_prompt_tokens: float
27+
cost_completion_tokens: float
28+
providers: int
29+
created: int
30+
updated: int
3131

3232
def with_name(self, name: str) -> "Router":
3333
return self.model_copy(update={"name": name})

0 commit comments

Comments
 (0)