Commit 7f17bb6

Multi-task support + /pipeline/<task> support for api-inference backward compat
Signed-off-by: Raphael Glon <[email protected]>
1 parent f0a2df1 commit 7f17bb6

File tree

1 file changed: +30 −8 lines

src/huggingface_inference_toolkit/webservice_starlette.py (+30 −8)

```diff
@@ -1,4 +1,5 @@
 import os
+import threading
 from pathlib import Path
 from time import perf_counter
 
@@ -16,6 +17,7 @@
     HF_REVISION,
     HF_TASK,
 )
+from huggingface_inference_toolkit.env_utils import api_inference_compat
 from huggingface_inference_toolkit.handler import (
     get_inference_handler_either_custom_or_default_handler,
 )
@@ -28,9 +30,11 @@
 )
 from huggingface_inference_toolkit.vertex_ai_utils import _load_repository_from_gcs
 
+INFERENCE_HANDLERS = {}
+INFERENCE_HANDLERS_LOCK = threading.Lock()
 
 async def prepare_model_artifacts():
-    global inference_handler
+    global INFERENCE_HANDLERS
     # 1. check if model artifacts available in HF_MODEL_DIR
     if len(list(Path(HF_MODEL_DIR).glob("**/*"))) <= 0:
         # 2. if not available, try to load from HF_MODEL_ID
@@ -62,6 +66,7 @@ async def prepare_model_artifacts():
     inference_handler = get_inference_handler_either_custom_or_default_handler(
         HF_MODEL_DIR, task=HF_TASK
     )
+    INFERENCE_HANDLERS[HF_TASK] = inference_handler
     logger.info("Model initialized successfully")
 
 
@@ -82,6 +87,7 @@ async def metrics(request):
 
 
 async def predict(request):
+    global INFERENCE_HANDLERS
     try:
         # extracts content from request
         content_type = request.headers.get("content-Type", os.environ.get("DEFAULT_CONTENT_TYPE")).lower()
@@ -101,6 +107,17 @@ async def predict(request):
             dict(request.query_params)
         )
 
+        # We lazily load pipelines for alt tasks
+        task = request.path_params.get("task", HF_TASK)
+        inference_handler = INFERENCE_HANDLERS.get(task)
+        if not inference_handler:
+            with INFERENCE_HANDLERS_LOCK:
+                if task not in INFERENCE_HANDLERS:
+                    inference_handler = get_inference_handler_either_custom_or_default_handler(
+                        HF_MODEL_DIR, task=task)
+                    INFERENCE_HANDLERS[task] = inference_handler
+                else:
+                    inference_handler = INFERENCE_HANDLERS[task]
         # tracks request time
         start_time = perf_counter()
         # run async not blocking call
@@ -149,14 +166,19 @@ async def predict(request):
         on_startup=[prepare_model_artifacts],
     )
 else:
+    routes = [
+        Route("/", health, methods=["GET"]),
+        Route("/health", health, methods=["GET"]),
+        Route("/", predict, methods=["POST"]),
+        Route("/predict", predict, methods=["POST"]),
+        Route("/metrics", metrics, methods=["GET"]),
+    ]
+    if api_inference_compat():
+        routes.append(
+            Route("/pipeline/{task:path}", predict, methods=["POST"])
+        )
     app = Starlette(
         debug=False,
-        routes=[
-            Route("/", health, methods=["GET"]),
-            Route("/health", health, methods=["GET"]),
-            Route("/", predict, methods=["POST"]),
-            Route("/predict", predict, methods=["POST"]),
-            Route("/metrics", metrics, methods=["GET"]),
-        ],
+        routes=routes,
         on_startup=[prepare_model_artifacts],
     )
```
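For context, the lazy per-task cache added in `predict()` is double-checked locking: the handler dict is read without the lock on the hot path, and the lock is taken only to serialize first-time loads, so two concurrent requests for a new task don't both build a pipeline. A minimal standalone sketch of the same pattern (hypothetical names, not part of the toolkit):

```python
# Minimal sketch of the double-checked locking pattern used in the commit.
# expensive_load() is a hypothetical stand-in for building a task handler.
import threading

_CACHE = {}
_CACHE_LOCK = threading.Lock()

def expensive_load(task: str) -> str:
    # Stand-in for constructing a pipeline/handler for the given task.
    return f"handler-for-{task}"

def get_handler(task: str) -> str:
    handler = _CACHE.get(task)          # fast path: no lock on a cache hit
    if handler is None:
        with _CACHE_LOCK:               # slow path: serialize first-time loads
            handler = _CACHE.get(task)  # re-check: another thread may have won
            if handler is None:
                handler = expensive_load(task)
                _CACHE[task] = handler
    return handler
```

The second lookup inside the lock is what prevents a duplicate load: without it, two threads that both miss the fast path would each call the loader.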

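On the routing side, `{task:path}` uses Starlette's `path` converter, so the captured task segment may itself contain slashes; requests to `/` or `/predict` carry no `task` path parameter, so `request.path_params.get("task", HF_TASK)` falls back to the task the server was started with. A hypothetical client call against the new backward-compat route, assuming the service listens on localhost:5000 and the model supports the requested task (both assumptions, not stated in the commit):

```python
# Hypothetical client for the /pipeline/<task> route added above.
# Assumes the service listens on localhost:5000 (adjust to your deployment).
import requests

response = requests.post(
    "http://localhost:5000/pipeline/text-classification",
    json={"inputs": "This commit keeps old api-inference clients working."},
)
response.raise_for_status()
print(response.json())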