-
Notifications
You must be signed in to change notification settings - Fork 1.6k
/
Copy pathmain.py
147 lines (117 loc) · 5.2 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import logging
import os
import shutil
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from pathlib import Path
import sentry_sdk
import torch
import uvicorn
from fastapi import FastAPI
from prometheus_fastapi_instrumentator import Instrumentator
from sentry_sdk.integrations.fastapi import FastApiIntegration
from sentry_sdk.integrations.starlette import StarletteIntegration
from transformers import logging as transformer_logging # type:ignore
from model_server.custom_models import router as custom_models_router
from model_server.custom_models import warm_up_information_content_model
from model_server.custom_models import warm_up_intent_model
from model_server.encoders import router as encoders_router
from model_server.management_endpoints import router as management_router
from model_server.utils import get_gpu_type
from onyx import __version__
from onyx.utils.logger import setup_logger
from onyx.utils.logger import setup_uvicorn_logger
from onyx.utils.middleware import add_onyx_request_id_middleware
from onyx.utils.middleware import add_onyx_tenant_id_middleware
from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import MIN_THREADS_ML_MODELS
from shared_configs.configs import MODEL_SERVER_ALLOWED_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import SENTRY_DSN
# Environment flags applied at module import time.
# NOTE(review): both variables are presumably consumed by the Hugging Face
# tokenizers / hub libraries -- confirm the exact semantics in their docs.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
# Cache directories under the current user's home directory. During application
# startup the contents of the temp cache are moved into the main cache and the
# temp directory is removed.
HF_CACHE_PATH = Path(os.path.expanduser("~")) / ".cache/huggingface"
TEMP_HF_CACHE_PATH = Path(os.path.expanduser("~")) / ".cache/temp_huggingface"
# Restrict the transformers library's own logging to errors.
transformer_logging.set_verbosity_error()
# Module-wide logger used by everything below.
logger = setup_logger()
# Collect the file handlers attached to the logger and hand them to the
# uvicorn logger setup as its shared file handlers.
file_handlers = [
h for h in logger.logger.handlers if isinstance(h, logging.FileHandler)
]
setup_uvicorn_logger(shared_file_handlers=file_handlers)
def _move_files_recursively(source: Path, dest: Path, overwrite: bool = False) -> None:
    """
    Merge the file tree under `source` into `dest`, one file at a time.

    Used to fold the temp huggingface cache into the real huggingface cache.
    Directories are never moved wholesale: two directories may share a name
    while holding different files, and anything that already lives in `dest`
    but has no counterpart in `source` must survive the merge.

    A file whose destination already exists is left in `source` untouched
    unless `overwrite` is true.
    """
    for entry in source.iterdir():
        # iterdir() yields direct children only, so the child's name is its
        # path relative to `source`.
        destination = dest / entry.name

        if entry.is_dir():
            # Descend instead of moving the directory as a unit.
            _move_files_recursively(entry, destination, overwrite)
            continue

        destination.parent.mkdir(parents=True, exist_ok=True)
        if overwrite or not destination.exists():
            shutil.move(str(entry), str(destination))
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
    """
    Lifespan context manager passed to the FastAPI application.

    Everything before the `yield` runs as startup work, in this order:
      1. detect the GPU type and store it on `app.state.gpu_type`;
      2. best-effort merge of the temp huggingface cache into the real
         cache (any failure is logged as a warning and otherwise ignored);
      3. raise torch's thread count to at least MIN_THREADS_ML_MODELS;
      4. warm up exactly one custom model, chosen by INDEXING_ONLY.

    No work is performed after the `yield`.
    """
    detected_gpu = get_gpu_type()
    logger.notice(f"Torch GPU Detection: gpu_type={detected_gpu}")
    app.state.gpu_type = detected_gpu

    # Cache merge is optional: a failure here must not stop the server.
    try:
        if TEMP_HF_CACHE_PATH.is_dir():
            logger.notice("Moving contents of temp_huggingface to huggingface cache.")
            _move_files_recursively(TEMP_HF_CACHE_PATH, HF_CACHE_PATH)
            # Files that were skipped (already present in the real cache)
            # are discarded together with the temp directory.
            shutil.rmtree(TEMP_HF_CACHE_PATH, ignore_errors=True)
            logger.notice("Moved contents of temp_huggingface to huggingface cache.")
    except Exception as exc:
        logger.warning(
            f"Error moving contents of temp_huggingface to huggingface cache: {exc}. "
            "This is not a critical error and the model server will continue to run."
        )

    # Never lower the thread count, only raise it to the configured minimum.
    thread_floor = max(MIN_THREADS_ML_MODELS, torch.get_num_threads())
    torch.set_num_threads(thread_floor)
    logger.notice(f"Torch Threads: {torch.get_num_threads()}")

    if INDEXING_ONLY:
        logger.notice(
            "The content information model should run on the indexing model server. The intent model should not run here."
        )
        warm_up_information_content_model()
    else:
        logger.notice(
            "The intent model should run on the model server. The information content model should not run here."
        )
        warm_up_intent_model()

    yield
def get_model_app() -> FastAPI:
    """
    Build and return the fully configured model server application.

    Steps, in order: create the FastAPI app with `lifespan` attached,
    initialize Sentry when SENTRY_DSN is set, register the management,
    encoder and custom-model routers, install the tenant-id and request-id
    middleware, and instrument the app via `Instrumentator`.
    """
    application = FastAPI(
        title="Onyx Model Server",
        version=__version__,
        lifespan=lifespan,
    )

    if not SENTRY_DSN:
        logger.debug("Sentry DSN not provided, skipping Sentry initialization")
    else:
        sentry_sdk.init(
            dsn=SENTRY_DSN,
            integrations=[StarletteIntegration(), FastApiIntegration()],
            traces_sample_rate=0.1,
        )
        logger.info("Sentry initialized")

    # Registration order is kept: management, encoders, custom models.
    for router in (management_router, encoders_router, custom_models_router):
        application.include_router(router)

    # Request ids are prefixed "IDX" when INDEXING_ONLY is set, "INF" otherwise.
    request_id_prefix = "IDX" if INDEXING_ONLY else "INF"

    add_onyx_tenant_id_middleware(application, logger)
    add_onyx_request_id_middleware(application, request_id_prefix, logger)

    # Initialize and instrument the app
    instrumentator = Instrumentator()
    instrumentator.instrument(application).expose(application)

    return application
# Module-level application instance, built once at import time.
app = get_model_app()


if __name__ == "__main__":
    # Direct execution: announce the bind address and version, then serve
    # the app with uvicorn on the configured host and port.
    host, port = MODEL_SERVER_ALLOWED_HOST, MODEL_SERVER_PORT
    logger.notice(f"Starting Onyx Model Server on http://{host}:{str(port)}/")
    logger.notice(f"Model Server Version: {__version__}")
    uvicorn.run(app, host=host, port=port)