@@ -1,4 +1,5 @@
 import os
+import threading
 from pathlib import Path
 from time import perf_counter
 
@@ -16,6 +17,7 @@
     HF_REVISION,
     HF_TASK,
 )
+from huggingface_inference_toolkit.env_utils import api_inference_compat
 from huggingface_inference_toolkit.handler import (
     get_inference_handler_either_custom_or_default_handler,
 )
@@ -28,9 +30,11 @@
 )
 from huggingface_inference_toolkit.vertex_ai_utils import _load_repository_from_gcs
 
+INFERENCE_HANDLERS = {}
+INFERENCE_HANDLERS_LOCK = threading.Lock()
 
 async def prepare_model_artifacts():
-    global inference_handler
+    global INFERENCE_HANDLERS
     # 1. check if model artifacts available in HF_MODEL_DIR
     if len(list(Path(HF_MODEL_DIR).glob("**/*"))) <= 0:
         # 2. if not available, try to load from HF_MODEL_ID
@@ -62,6 +66,7 @@ async def prepare_model_artifacts():
     inference_handler = get_inference_handler_either_custom_or_default_handler(
         HF_MODEL_DIR, task=HF_TASK
     )
+    INFERENCE_HANDLERS[HF_TASK] = inference_handler
     logger.info("Model initialized successfully")
 
 
@@ -82,6 +87,7 @@ async def metrics(request):
 
 
 async def predict(request):
+    global INFERENCE_HANDLERS
     try:
         # extracts content from request
         content_type = request.headers.get("content-Type", os.environ.get("DEFAULT_CONTENT_TYPE")).lower()
@@ -101,6 +107,17 @@ async def predict(request):
             dict(request.query_params)
         )
 
+        # Lazily build and cache handlers for alternate tasks
+        task = request.path_params.get("task", HF_TASK)
+        inference_handler = INFERENCE_HANDLERS.get(task)
+        if not inference_handler:
+            with INFERENCE_HANDLERS_LOCK:
+                if task not in INFERENCE_HANDLERS:
+                    inference_handler = get_inference_handler_either_custom_or_default_handler(
+                        HF_MODEL_DIR, task=task)
+                    INFERENCE_HANDLERS[task] = inference_handler
+                else:
+                    inference_handler = INFERENCE_HANDLERS[task]
         # tracks request time
         start_time = perf_counter()
         # run async not blocking call
@@ -149,14 +166,19 @@ async def predict(request):
         on_startup=[prepare_model_artifacts],
     )
 else:
+    routes = [
+        Route("/", health, methods=["GET"]),
+        Route("/health", health, methods=["GET"]),
+        Route("/", predict, methods=["POST"]),
+        Route("/predict", predict, methods=["POST"]),
+        Route("/metrics", metrics, methods=["GET"]),
+    ]
+    if api_inference_compat():
+        routes.append(
+            Route("/pipeline/{task:path}", predict, methods=["POST"])
+        )
     app = Starlette(
         debug=False,
-        routes=[
-            Route("/", health, methods=["GET"]),
-            Route("/health", health, methods=["GET"]),
-            Route("/", predict, methods=["POST"]),
-            Route("/predict", predict, methods=["POST"]),
-            Route("/metrics", metrics, methods=["GET"]),
-        ],
+        routes=routes,
         on_startup=[prepare_model_artifacts],
     )
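The cache added to `predict` is a double-checked locking pattern: the first, lock-free lookup keeps the common path cheap, and the re-check under the lock guarantees that two requests racing on the same uncached task build the handler only once. A minimal standalone sketch of the same pattern, with a hypothetical `load_handler` callable standing in for `get_inference_handler_either_custom_or_default_handler`:

```python
import threading

_handlers = {}
_lock = threading.Lock()

def get_or_create(task, load_handler):
    # Fast path: cached handlers are returned without taking the lock.
    handler = _handlers.get(task)
    if handler is not None:
        return handler
    # Slow path: take the lock, then re-check, so the loser of a race
    # reuses the handler the winner just built instead of duplicating it.
    with _lock:
        if task not in _handlers:
            _handlers[task] = load_handler(task)
        return _handlers[task]
```

Note that `threading.Lock` must be used directly as a context manager (`with _lock:`); `with _lock.acquire():` fails at runtime because `acquire()` returns a bool.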
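A rough usage sketch of the new routing, assuming `api_inference_compat()` evaluated truthy when the module was imported (so the `/pipeline/{task:path}` route was registered) and that the underlying model actually supports the alternate task; the payload is illustrative only:

```python
from starlette.testclient import TestClient

# Entering the client context runs the on_startup hooks (model loading).
with TestClient(app) as client:
    # Default task (HF_TASK), unchanged behavior:
    resp = client.post("/predict", json={"inputs": "I love this movie."})

    # Alternate task taken from the URL path; its handler is built lazily
    # on the first request and then served from INFERENCE_HANDLERS:
    resp = client.post(
        "/pipeline/text-classification",
        json={"inputs": "I love this movie."},
    )
    print(resp.status_code, resp.json())
```

Because the route uses Starlette's `path` converter, everything after `/pipeline/` (slashes included) is captured as `task`; on the `/` and `/predict` routes, `path_params` has no `task` key, so the handler for `HF_TASK` is used.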