|
1 | | -from enum import Enum, StrEnum |
2 | | -from typing import Literal |
| 1 | +from enum import StrEnum |
3 | 2 |
|
4 | 3 | import pycountry |
5 | | -from pydantic import Field, constr |
6 | 4 |
|
7 | 5 | from api.domain import EntitiesPage |
8 | 6 | from api.domain.model.entities import ModelType |
|
12 | 10 |
|
13 | 11 | # Add world as a country code, default value of the carbon footprint computation framework |
14 | 12 | country_codes = [country.alpha_3 for country in pycountry.countries] + ["WOR"] |
15 | | -country_codes_dict = {str(code).upper(): str(code) for code in sorted(set(country_codes))} |
16 | | -ProviderCarbonFootprintZone: type[Enum] = Enum("ProviderCarbonFootprintZone", country_codes_dict, type=str) |
| 13 | +ProviderCarbonFootprintZone = StrEnum("ProviderCarbonFootprintZone", {str(code).upper(): str(code) for code in sorted(set(country_codes))}) |
17 | 14 |
|
18 | 15 |
|
19 | 16 | class ProviderType(StrEnum): |
@@ -72,22 +69,21 @@ class ProviderSortField(StrEnum): |
72 | 69 |
|
73 | 70 |
|
74 | 71 | class Provider(BaseModel): |
75 | | - object: Literal["provider"] = "provider" |
76 | | - id: int = Field(..., description="Provider ID.") # fmt: off |
77 | | - router_id: int = Field(..., description="ID of the router that owns the provider.") # fmt: off |
78 | | - user_id: int = Field(..., description="ID of the user that owns the provider.") # fmt: off |
79 | | - type: ProviderType = Field(..., description="Provider type.") # fmt: off |
80 | | - url: constr(strip_whitespace=True, min_length=1, to_lower=True) | None = Field(default=None, description="Provider API url. The url must only contain the domain name (without `/v1` suffix for example).") # fmt: off |
81 | | - key: str | None = Field(description="Provider API key.") # fmt: off |
82 | | - timeout: int = Field(..., description="Timeout for the provider requests, after user receive an 500 error (model is too busy).") # fmt: off |
83 | | - model_name: str = Field(..., description="Model name from the model provider.") # fmt: off |
84 | | - model_hosting_zone: ProviderCarbonFootprintZone = Field(default=ProviderCarbonFootprintZone.WOR, description="Model hosting zone using ISO 3166-1 alpha-3 code format (e.g., `WOR` for World, `FRA` for France, `USA` for United States). This determines the electricity mix used for carbon intensity calculations. For more information, see https://ecologits.ai", examples=["WOR"]) # fmt: off |
85 | | - model_total_params: int = Field(default=0, ge=0, description="Total params of the model in billions of parameters for carbon footprint computation. If not provided, the active params will be used if provided, else carbon footprint will not be computed. For more information, see https://ecologits.ai") # fmt: off |
86 | | - model_active_params: int = Field(default=0, ge=0, description="Active params of the model in billions of parameters for carbon footprint computation. If not provided, the total params will be used if provided, else carbon footprint will not be computed. For more information, see https://ecologits.ai") # fmt: off |
87 | | - qos_metric: Metric | None = Field(description="The metric to use for the QoS policy. If not provided, no QoS policy is applied.") # fmt: off |
88 | | - qos_limit: float | None = Field(default=None, ge=0.0, description="The value to use for the quality of service. Depends of the metric, the value can be a percentile, a threshold, etc.") # fmt: off |
89 | | - created: int | None = Field(default=None, description="Time of creation, as Unix timestamp.") # fmt: off |
90 | | - updated: int | None = Field(default=None, description="Time of last update, as Unix timestamp.") # fmt: off |
| 72 | + id: int |
| 73 | + router_id: int |
| 74 | + user_id: int |
| 75 | + type: ProviderType |
| 76 | + url: str |
| 77 | + key: str | None = None |
| 78 | + timeout: int |
| 79 | + model_name: str |
| 80 | + model_hosting_zone: ProviderCarbonFootprintZone = ProviderCarbonFootprintZone.WOR |
| 81 | + model_total_params: int = 0 |
| 82 | + model_active_params: int = 0 |
| 83 | + qos_metric: Metric | None = None |
| 84 | + qos_limit: float | None = None |
| 85 | + created: int |
| 86 | + updated: int |
91 | 87 |
|
92 | 88 | def with_router_id(self, router_id: int) -> "Provider": |
93 | 89 | return self.model_copy(update={"router_id": router_id}) |
|
0 commit comments