Skip to content

Commit e278d06

Browse files
authored
Merge pull request #149 from NillionNetwork/feat/env_model_settings
feat: added model settings
2 parents b30dafd + 05dde87 commit e278d06

File tree

3 files changed

+21
-5
lines changed

3 files changed

+21
-5
lines changed

nilai-models/src/nilai_models/daemon.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from nilai_common import ( # Model service discovery, host settings, and model retry settings
88
SETTINGS,
9+
MODEL_SETTINGS,
910
ModelServiceDiscovery,
1011
ModelEndpoint,
1112
ModelMetadata,
@@ -14,7 +15,7 @@
1415
logger = logging.getLogger(__name__)
1516

1617

17-
async def get_metadata(num_retries=30):
18+
async def get_metadata():
1819
"""Fetch model metadata from model
1920
service and return as ModelMetadata object"""
2021
current_retries = 0
@@ -46,9 +47,13 @@ async def get_metadata(num_retries=30):
4647
else:
4748
logger.warning(f"Failed to fetch model metadata from {url}: {e}")
4849
current_retries += 1
49-
if current_retries >= num_retries:
50+
if (
51+
MODEL_SETTINGS.num_retries
52+
!= -1 # -1 is a sentinel meaning "retry indefinitely", so the retry limit is never enforced
53+
and current_retries >= MODEL_SETTINGS.num_retries
54+
):
5055
raise e
51-
await asyncio.sleep(10)
56+
await asyncio.sleep(MODEL_SETTINGS.timeout)
5257

5358

5459
async def run_service(discovery_service, model_endpoint):

packages/nilai-common/src/nilai_common/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
Message,
1717
MessageAdapter,
1818
)
19-
from nilai_common.config import SETTINGS
19+
from nilai_common.config import SETTINGS, MODEL_SETTINGS
2020
from nilai_common.discovery import ModelServiceDiscovery
2121
from openai.types.completion_usage import CompletionUsage as Usage
2222

@@ -36,6 +36,7 @@
3636
"AMDAttestationToken",
3737
"NVAttestationToken",
3838
"SETTINGS",
39+
"MODEL_SETTINGS",
3940
"SearchResult",
4041
"Source",
4142
"WebSearchEnhancedMessages",

packages/nilai-common/src/nilai_common/config.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from pydantic import BaseModel
2+
from pydantic import BaseModel, Field
33

44

55
class HostSettings(BaseModel):
@@ -13,6 +13,11 @@ class HostSettings(BaseModel):
1313
attestation_port: int = 8081
1414

1515

16+
class ModelSettings(BaseModel):
17+
num_retries: int = Field(default=30, ge=-1)
18+
timeout: int = Field(default=10, ge=1)
19+
20+
1621
SETTINGS: HostSettings = HostSettings(
1722
host=str(os.getenv("SVC_HOST", "localhost")),
1823
port=int(os.getenv("SVC_PORT", 8000)),
@@ -23,3 +28,8 @@ class HostSettings(BaseModel):
2328
attestation_host=str(os.getenv("ATTESTATION_HOST", "localhost")),
2429
attestation_port=int(os.getenv("ATTESTATION_PORT", 8081)),
2530
)
31+
32+
MODEL_SETTINGS: ModelSettings = ModelSettings(
33+
num_retries=int(os.getenv("MODEL_NUM_RETRIES", 30)),
34+
timeout=int(os.getenv("MODEL_RETRY_TIMEOUT", 10)),
35+
)

0 commit comments

Comments
 (0)