Skip to content

Commit 3aadc67

Browse files
committed
Refactor get_*_model_name to avoid caching fallback model name
1 parent 8e2d80f commit 3aadc67

File tree

2 files changed

+71
-15
lines changed
  • api/src/nv_ingest_api/internal/primitives/nim/model_interface

2 files changed

+71
-15
lines changed

api/src/nv_ingest_api/internal/primitives/nim/model_interface/ocr.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
import tritonclient.grpc as grpcclient
1717

1818
from nv_ingest_api.internal.primitives.nim import ModelInterface
19-
from nv_ingest_api.internal.primitives.nim.model_interface.decorators import multiprocessing_cache
19+
from nv_ingest_api.internal.primitives.nim.model_interface.decorators import global_cache
20+
from nv_ingest_api.internal.primitives.nim.model_interface.decorators import lock
2021
from nv_ingest_api.internal.primitives.nim.model_interface.helpers import preprocess_image_for_paddle
2122
from nv_ingest_api.util.image_processing.transforms import base64_to_numpy
2223

@@ -752,12 +753,11 @@ def _format_single_batch(
752753
raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")
753754

754755

755-
@multiprocessing_cache(max_calls=100) # Cache results first to avoid redundant retries from backoff
756756
@backoff.on_predicate(backoff.expo, max_time=30)
757757
def get_ocr_model_name(ocr_grpc_endpoint=None, default_model_name=DEFAULT_OCR_MODEL_NAME):
758758
"""
759759
Determines the OCR model name by checking the environment, querying the gRPC endpoint,
760-
or falling back to a default.
760+
or falling back to a default. Only caches when the repository is successfully queried.
761761
"""
762762
# 1. Check for an explicit override from the environment variable first.
763763
ocr_model_name = os.getenv("OCR_MODEL_NAME", None)
@@ -769,14 +769,25 @@ def get_ocr_model_name(ocr_grpc_endpoint=None, default_model_name=DEFAULT_OCR_MO
769769
logger.debug(f"No OCR gRPC endpoint provided. Falling back to default model name '{default_model_name}'.")
770770
return default_model_name
771771

772-
# 3. Attempt to query the gRPC endpoint to discover the model name.
772+
# 3. Check cache (only populated on successful repository query).
773+
key = (
774+
"get_ocr_model_name",
775+
(ocr_grpc_endpoint,),
776+
frozenset({"default_model_name": default_model_name}.items()),
777+
)
778+
with lock:
779+
if key in global_cache:
780+
return global_cache[key]
781+
782+
# 4. Attempt to query the gRPC endpoint to discover the model name.
773783
try:
774784
client = grpcclient.InferenceServerClient(ocr_grpc_endpoint)
775785
model_index = client.get_model_repository_index(as_json=True)
776786
model_names = [x["name"] for x in model_index.get("models", [])]
777787
ocr_model_name = model_names[0]
788+
with lock:
789+
global_cache[key] = ocr_model_name
790+
return ocr_model_name
778791
except Exception:
779792
logger.warning(f"Failed to get ocr model name after 30 seconds. Falling back to '{default_model_name}'.")
780-
ocr_model_name = default_model_name
781-
782-
return ocr_model_name
793+
return default_model_name

api/src/nv_ingest_api/internal/primitives/nim/model_interface/yolox.py

Lines changed: 53 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
from nv_ingest_api.internal.primitives.nim import ModelInterface
2222
import tritonclient.grpc as grpcclient
23+
from nv_ingest_api.internal.primitives.nim.model_interface.decorators import global_cache
24+
from nv_ingest_api.internal.primitives.nim.model_interface.decorators import lock
2325
from nv_ingest_api.internal.primitives.nim.model_interface.decorators import multiprocessing_cache
2426
from nv_ingest_api.internal.primitives.nim.model_interface.helpers import get_model_name
2527
from nv_ingest_api.util.image_processing import scale_image_to_encoding_size
@@ -135,10 +137,36 @@ def __init__(
135137
self.class_labels = class_labels
136138

137139
if endpoints:
138-
self.model_name = get_yolox_model_name(endpoints[0], default_model_name="yolox_ensemble")
139-
self._grpc_uses_bls = self.model_name == "pipeline"
140+
self._yolox_grpc_endpoint = endpoints[0]
141+
self._model_name = None
142+
self._grpc_uses_bls_value = None # Resolved on first use
140143
else:
141-
self._grpc_uses_bls = False
144+
self._yolox_grpc_endpoint = None
145+
self._model_name = None
146+
self._grpc_uses_bls_value = False
147+
148+
def _resolve_yolox_model_name_if_needed(self) -> None:
149+
"""Resolve model name and BLS flag from the gRPC endpoint on first use. Cached on the instance."""
150+
if self._yolox_grpc_endpoint is None:
151+
return
152+
if self._model_name is not None:
153+
return
154+
self._model_name = get_yolox_model_name(self._yolox_grpc_endpoint, default_model_name="yolox_ensemble")
155+
self._grpc_uses_bls_value = self._model_name == "pipeline"
156+
157+
@property
158+
def model_name(self) -> Optional[str]:
159+
self._resolve_yolox_model_name_if_needed()
160+
return self._model_name
161+
162+
@model_name.setter
163+
def model_name(self, value: Optional[str]) -> None:
164+
self._model_name = value
165+
166+
@property
167+
def _grpc_uses_bls(self) -> bool:
168+
self._resolve_yolox_model_name_if_needed()
169+
return bool(self._grpc_uses_bls_value)
142170

143171
def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
144172
"""
@@ -2117,7 +2145,6 @@ def postprocess_included_texts(boxes, confs, labels, classes):
21172145
return boxes, labels, confs
21182146

21192147

2120-
@multiprocessing_cache(max_calls=100) # Cache results first to avoid redundant retries from backoff
21212148
@backoff.on_predicate(backoff.expo, max_time=30)
21222149
def get_yolox_model_name(yolox_grpc_endpoint, default_model_name="yolox"):
21232150
# If a gRPC endpoint isn't provided (common when using HTTP-only NIM endpoints),
@@ -2131,6 +2158,15 @@ def get_yolox_model_name(yolox_grpc_endpoint, default_model_name="yolox"):
21312158
):
21322159
return default_model_name
21332160

2161+
key = (
2162+
"get_yolox_model_name",
2163+
(yolox_grpc_endpoint,),
2164+
frozenset({"default_model_name": default_model_name}.items()),
2165+
)
2166+
with lock:
2167+
if key in global_cache:
2168+
return global_cache[key]
2169+
21342170
try:
21352171
client = grpcclient.InferenceServerClient(yolox_grpc_endpoint)
21362172
model_index = client.get_model_repository_index(as_json=True)
@@ -2148,14 +2184,23 @@ def get_yolox_model_name(yolox_grpc_endpoint, default_model_name="yolox"):
21482184
"nemoretriever-page-elements-v2",
21492185
):
21502186
if preferred in model_names:
2151-
return preferred
2187+
result = preferred
2188+
with lock:
2189+
global_cache[key] = result
2190+
return result
21522191

21532192
# Otherwise pick a best-effort match for newer model names.
21542193
candidates = [m for m in model_names if isinstance(m, str) and ("yolox" in m or "page-elements" in m)]
21552194
if candidates:
2156-
return sorted(candidates)[0]
2157-
2158-
return default_model_name
2195+
result = sorted(candidates)[0]
2196+
with lock:
2197+
global_cache[key] = result
2198+
return result
2199+
2200+
result = default_model_name
2201+
with lock:
2202+
global_cache[key] = result
2203+
return result
21592204
except Exception as e:
21602205
logger.warning(
21612206
"Failed to inspect YOLOX model repository at '%s' (%s). Falling back to '%s'.",

0 commit comments

Comments
 (0)