@@ -1,4 +1,5 @@
 import os
+import threading
 from pathlib import Path
 from time import perf_counter
 
@@ -16,6 +17,7 @@
     HF_REVISION,
     HF_TASK,
 )
+from huggingface_inference_toolkit.env_utils import api_inference_compat
 from huggingface_inference_toolkit.handler import (
     get_inference_handler_either_custom_or_default_handler,
 )
@@ -28,9 +30,11 @@
 )
 from huggingface_inference_toolkit.vertex_ai_utils import _load_repository_from_gcs
 
+INFERENCE_HANDLERS = {}
+INFERENCE_HANDLERS_LOCK = threading.Lock()
 
 async def prepare_model_artifacts():
-    global inference_handler
+    global INFERENCE_HANDLERS
     # 1. check if model artifacts available in HF_MODEL_DIR
     if len(list(Path(HF_MODEL_DIR).glob("**/*"))) <= 0:
         # 2. if not available, try to load from HF_MODEL_ID
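One detail worth noting on the new module-level lock: `threading.Lock` works as a context manager, so guarding the registry with `with INFERENCE_HANDLERS_LOCK:` acquires on entry and releases on exit even if handler construction raises. A minimal sketch of the equivalence:

```python
import threading

lock = threading.Lock()

# The `with` form below...
with lock:
    pass  # critical section

# ...is equivalent to the explicit acquire/release pairing,
# including the release when the body raises.
lock.acquire()
try:
    pass  # critical section
finally:
    lock.release()
```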
@@ -62,6 +66,7 @@ async def prepare_model_artifacts():
     inference_handler = get_inference_handler_either_custom_or_default_handler(
         HF_MODEL_DIR, task=HF_TASK
     )
+    INFERENCE_HANDLERS[HF_TASK] = inference_handler
     logger.info("Model initialized successfully")
 
 
@@ -101,6 +106,17 @@ async def predict(request):
                 dict(request.query_params)
             )
 
+        # We lazily build and cache handlers for alternate tasks
+        task = request.path_params.get("task", HF_TASK)
+        inference_handler = INFERENCE_HANDLERS.get(task)
+        if not inference_handler:
+            with INFERENCE_HANDLERS_LOCK:
+                inference_handler = INFERENCE_HANDLERS.get(task)
+                if not inference_handler:
+                    inference_handler = get_inference_handler_either_custom_or_default_handler(
+                        HF_MODEL_DIR, task=task
+                    )
+                    INFERENCE_HANDLERS[task] = inference_handler
         # tracks request time
         start_time = perf_counter()
         # run async not blocking call
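The re-read of `INFERENCE_HANDLERS` inside the lock is the double-checked locking pattern: the first unlocked `get` keeps the hot path cheap, and the second one under the lock stops concurrent first requests from each building a handler. A toy illustration of the guarantee (the dict, counter, and dummy handler here are stand-ins, not toolkit code):

```python
import threading

cache = {}
lock = threading.Lock()
builds = []  # records every real construction

def get_or_create(task):
    handler = cache.get(task)
    if handler is None:
        with lock:
            handler = cache.get(task)  # re-check under the lock
            if handler is None:
                builds.append(task)
                handler = object()  # stand-in for an expensive pipeline
                cache[task] = handler
    return handler

threads = [threading.Thread(target=get_or_create, args=("fill-mask",)) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
assert len(builds) == 1  # one construction despite eight racing callers
```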
@@ -149,14 +165,19 @@ async def predict(request):
         on_startup=[prepare_model_artifacts],
     )
 else:
+    routes = [
+        Route("/", health, methods=["GET"]),
+        Route("/health", health, methods=["GET"]),
+        Route("/", predict, methods=["POST"]),
+        Route("/predict", predict, methods=["POST"]),
+        Route("/metrics", metrics, methods=["GET"]),
+    ]
+    if api_inference_compat():
+        routes.append(
+            Route("/pipeline/{task:path}", predict, methods=["POST"])
+        )
     app = Starlette(
         debug=False,
-        routes=[
-            Route("/", health, methods=["GET"]),
-            Route("/health", health, methods=["GET"]),
-            Route("/", predict, methods=["POST"]),
-            Route("/predict", predict, methods=["POST"]),
-            Route("/metrics", metrics, methods=["GET"]),
-        ],
+        routes=routes,
         on_startup=[prepare_model_artifacts],
     )
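Assuming a server launched from this module and reachable locally (the host, port, task, and payload below are illustrative), the compatibility route is exercised like the default one, with the path segment choosing which cached or lazily built handler serves the request:

```python
import requests  # third-party HTTP client, assumed available

base = "http://localhost:5000"  # illustrative address

# Default route: served by the handler registered for HF_TASK at startup.
requests.post(f"{base}/predict", json={"inputs": "I like you."})

# Compatibility route: the `task` path parameter (here
# "text-classification") overrides HF_TASK for this request.
resp = requests.post(
    f"{base}/pipeline/text-classification",
    json={"inputs": "I like you."},
)
print(resp.status_code, resp.json())
```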