Skip to content

Commit 77035e6

Browse files
committed
Multi-task support + /pipeline/<task> support for api-inference backward compatibility
Signed-off-by: Raphael Glon <[email protected]>
1 parent f0a2df1 commit 77035e6

File tree

1 file changed

+29
-8
lines changed

1 file changed

+29
-8
lines changed

src/huggingface_inference_toolkit/webservice_starlette.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import threading
23
from pathlib import Path
34
from time import perf_counter
45

@@ -16,6 +17,7 @@
1617
HF_REVISION,
1718
HF_TASK,
1819
)
20+
from huggingface_inference_toolkit.env_utils import api_inference_compat
1921
from huggingface_inference_toolkit.handler import (
2022
get_inference_handler_either_custom_or_default_handler,
2123
)
@@ -28,9 +30,11 @@
2830
)
2931
from huggingface_inference_toolkit.vertex_ai_utils import _load_repository_from_gcs
3032

33+
INFERENCE_HANDLERS = {}
34+
INFERENCE_HANDLERS_LOCK = threading.Lock()
3135

3236
async def prepare_model_artifacts():
33-
global inference_handler
37+
global INFERENCE_HANDLERS
3438
# 1. check if model artifacts available in HF_MODEL_DIR
3539
if len(list(Path(HF_MODEL_DIR).glob("**/*"))) <= 0:
3640
# 2. if not available, try to load from HF_MODEL_ID
@@ -62,6 +66,7 @@ async def prepare_model_artifacts():
6266
inference_handler = get_inference_handler_either_custom_or_default_handler(
6367
HF_MODEL_DIR, task=HF_TASK
6468
)
69+
INFERENCE_HANDLERS[HF_TASK] = inference_handler
6570
logger.info("Model initialized successfully")
6671

6772

@@ -101,6 +106,17 @@ async def predict(request):
101106
dict(request.query_params)
102107
)
103108

109+
# We lazily load pipelines for alt tasks
110+
task = request.path_params.get("task", HF_TASK)
111+
inference_handler = INFERENCE_HANDLERS.get(task)
112+
if not inference_handler:
113+
with INFERENCE_HANDLERS_LOCK:
115+
inference_handler = INFERENCE_HANDLERS.get(task)
116+
if not inference_handler:
117+
inference_handler = get_inference_handler_either_custom_or_default_handler(
118+
HF_MODEL_DIR, task=task)
119+
INFERENCE_HANDLERS[task] = inference_handler
104120
# tracks request time
105121
start_time = perf_counter()
106122
# run async not blocking call
@@ -149,14 +165,19 @@ async def predict(request):
149165
on_startup=[prepare_model_artifacts],
150166
)
151167
else:
168+
routes = [
169+
Route("/", health, methods=["GET"]),
170+
Route("/health", health, methods=["GET"]),
171+
Route("/", predict, methods=["POST"]),
172+
Route("/predict", predict, methods=["POST"]),
173+
Route("/metrics", metrics, methods=["GET"]),
174+
]
175+
if api_inference_compat():
176+
routes.append(
177+
Route("/pipeline/{task:path}", predict, methods=["POST"])
178+
)
152179
app = Starlette(
153180
debug=False,
154-
routes=[
155-
Route("/", health, methods=["GET"]),
156-
Route("/health", health, methods=["GET"]),
157-
Route("/", predict, methods=["POST"]),
158-
Route("/predict", predict, methods=["POST"]),
159-
Route("/metrics", metrics, methods=["GET"]),
160-
],
181+
routes=routes,
161182
on_startup=[prepare_model_artifacts],
162183
)

0 commit comments

Comments
 (0)