Skip to content

Commit 6b267b5

Browse files
authored
Merge pull request #54 from RobGeada/PrometheusHF
Feat: Modularize Prometheus metrics and add them to the HF runtime
2 parents cfe7f83 + 89087c2 commit 6b267b5

File tree

12 files changed

+222
-131
lines changed

12 files changed

+222
-131
lines changed

detectors/Dockerfile.hf

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,11 +96,12 @@ FROM builder
9696
WORKDIR /app
9797
ARG CACHEBUST=1
9898
RUN echo "$CACHEBUST"
99-
COPY ./common /common
99+
COPY ./common /app/detectors/common
100+
COPY ./huggingface/detector.py /app/detectors/huggingface/
101+
RUN mkdir /common; cp /app/detectors/common/log_conf.yaml /common/
100102
COPY ./huggingface/app.py /app
101-
COPY ./huggingface/detector.py /app
102103

103104
EXPOSE 8000
104-
CMD ["uvicorn", "app:app", "--workers", "4", "--host", "0.0.0.0", "--port", "8000", "--log-config", "/common/log_conf.yaml"]
105+
CMD ["uvicorn", "app:app", "--workers", "1", "--host", "0.0.0.0", "--port", "8000", "--log-config", "/common/log_conf.yaml"]
105106

106107
# gunicorn main:app --workers 4 --worker-class uvicorn.workers.UvicornWorker --bind 0.0.0.0:8000

detectors/built_in/base_detector_registry.py

Lines changed: 4 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,16 @@
1-
import contextlib
21
import logging
32
from abc import ABC, abstractmethod
4-
import time
5-
from http.client import HTTPException
3+
from fastapi import HTTPException
64
from typing import List
75

6+
from detectors.common.instrumented_detector import InstrumentedDetector
87
from detectors.common.scheme import ContentAnalysisResponse
98

10-
class BaseDetectorRegistry(ABC):
9+
class BaseDetectorRegistry(InstrumentedDetector, ABC):
1110
def __init__(self, registry_name):
11+
super().__init__(registry_name)
1212
self.registry = None
13-
self.registry_name = registry_name
1413

15-
# prometheus
16-
self.instruments = {}
1714

1815
@abstractmethod
1916
def handle_request(self, content: str, detector_params: dict, headers: dict, **kwargs) -> List[ContentAnalysisResponse]:
@@ -22,35 +19,6 @@ def handle_request(self, content: str, detector_params: dict, headers: dict, **k
2219
def get_registry(self):
2320
return self.registry
2421

25-
def add_instruments(self, gauges):
26-
self.instruments = gauges
27-
28-
def increment_detector_instruments(self, function_name: str, is_detection: bool):
29-
"""Increment the detection and request counters, automatically update rates"""
30-
if self.instruments.get("requests"):
31-
self.instruments["requests"].labels(self.registry_name, function_name).inc()
32-
33-
# The labels() function will initialize the counters if not already created.
34-
# This prevents the counters not existing until they are first incremented
35-
# If the counters have already been created, this is just a cheap dict.get() call
36-
if self.instruments.get("errors"):
37-
_ = self.instruments["errors"].labels(self.registry_name, function_name)
38-
if self.instruments.get("runtime"):
39-
_ = self.instruments["runtime"].labels(self.registry_name, function_name)
40-
41-
# create and/or increment the detection counter
42-
if self.instruments.get("detections"):
43-
detection_counter = self.instruments["detections"].labels(self.registry_name, function_name)
44-
if is_detection:
45-
detection_counter.inc()
46-
47-
48-
def increment_error_instruments(self, function_name: str):
49-
"""Increment the error counter, update the rate gauges"""
50-
if self.instruments.get("errors"):
51-
self.instruments["errors"].labels(self.registry_name, function_name).inc()
52-
53-
5422
def throw_internal_detector_error(self, function_name: str, logger: logging.Logger, exception: Exception, increment_requests: bool):
5523
"""consistent handling of internal errors within a detection function"""
5624
if increment_requests and self.instruments.get("requests"):
@@ -60,16 +28,6 @@ def throw_internal_detector_error(self, function_name: str, logger: logging.Logg
6028
raise HTTPException(status_code=500, detail="Detection error, check detector logs")
6129

6230

63-
@contextlib.contextmanager
64-
def instrument_runtime(self, function_name: str):
65-
try:
66-
start_time = time.time()
67-
yield
68-
if self.instruments.get("runtime"):
69-
self.instruments["runtime"].labels(self.registry_name, function_name).inc(time.time() - start_time)
70-
finally:
71-
pass
72-
7331
def get_detection_functions_from_params(self, params: dict):
7432
"""Parse the request parameters to extract and normalize detection functions as iterable list"""
7533
if self.registry_name in params and isinstance(params[self.registry_name], (list, str)):

detectors/common/app.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,12 @@
77
import yaml
88
from fastapi.exceptions import RequestValidationError
99
from fastapi.responses import JSONResponse
10-
from prometheus_client import Gauge, Counter
11-
12-
sys.path.insert(0, os.path.abspath(".."))
10+
from prometheus_client import Counter
1311

1412
import logging
1513

1614
from fastapi import FastAPI, status
1715
from starlette.exceptions import HTTPException as StarletteHTTPException
18-
from prometheus_fastapi_instrumentator import Instrumentator
1916

2017
logger = logging.getLogger(__name__)
2118
uvicorn_error_logger = logging.getLogger("uvicorn.error")
@@ -39,22 +36,22 @@ def __init__(self, *args, **kwargs):
3936
self.state.instruments = {
4037
"detections": Counter(
4138
"trustyai_guardrails_detections",
42-
"Number of detections per built-in detector function",
39+
"Number of detections per detector function",
4340
["detector_kind", "detector_name"]
4441
),
4542
"requests": Counter(
4643
"trustyai_guardrails_requests",
47-
"Number of requests per built-in detector function",
44+
"Number of requests per detector function",
4845
["detector_kind", "detector_name"]
4946
),
5047
"errors": Counter(
5148
"trustyai_guardrails_errors",
52-
"Number of errors per built-in detector function",
49+
"Number of errors per detector function",
5350
["detector_kind", "detector_name"]
5451
),
5552
"runtime": Counter(
5653
"trustyai_guardrails_runtime",
57-
"Total runtime of a built-in detector function- this is the induced latency of this guardrail",
54+
"Total runtime of a detector function- this is the induced latency of this guardrail",
5855
["detector_kind", "detector_name"]
5956
)
6057
}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import contextlib
2+
import time
3+
4+
5+
class InstrumentedDetector:
6+
def __init__(self, registry_name: str = "default"):
7+
self.registry_name = registry_name
8+
self.instruments = {}
9+
10+
@contextlib.contextmanager
11+
def instrument_runtime(self, function_name: str):
12+
try:
13+
start_time = time.time()
14+
yield
15+
if self.instruments.get("runtime"):
16+
self.instruments["runtime"].labels(self.registry_name, function_name).inc(time.time() - start_time)
17+
finally:
18+
pass
19+
20+
def add_instruments(self, gauges):
21+
self.instruments = gauges
22+
23+
def increment_detector_instruments(self, function_name: str, is_detection: bool):
24+
"""Increment the detection and request counters, automatically update rates"""
25+
if self.instruments.get("requests"):
26+
self.instruments["requests"].labels(self.registry_name, function_name).inc()
27+
28+
# The labels() function will initialize the counters if not already created.
29+
# This prevents the counters not existing until they are first incremented
30+
# If the counters have already been created, this is just a cheap dict.get() call
31+
if self.instruments.get("errors"):
32+
_ = self.instruments["errors"].labels(self.registry_name, function_name)
33+
if self.instruments.get("runtime"):
34+
_ = self.instruments["runtime"].labels(self.registry_name, function_name)
35+
36+
# create and/or increment the detection counter
37+
if self.instruments.get("detections"):
38+
detection_counter = self.instruments["detections"].labels(self.registry_name, function_name)
39+
if is_detection:
40+
detection_counter.inc()
41+
42+
def increment_error_instruments(self, function_name: str):
43+
"""Increment the error counter, update the rate gauges"""
44+
if self.instruments.get("errors"):
45+
self.instruments["errors"].labels(self.registry_name, function_name).inc()

detectors/common/requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ locust==2.31.1
33
pre-commit==3.8.0
44
pytest==8.3.2
55
tls-test-tools
6+
protobuf==6.33.0

detectors/huggingface/app.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,11 @@
1-
import os
2-
import sys
31
from contextlib import asynccontextmanager
4-
from typing import Annotated
2+
from typing import List
53

6-
from fastapi import Header
74
from prometheus_fastapi_instrumentator import Instrumentator
8-
sys.path.insert(0, os.path.abspath(".."))
9-
10-
from common.app import DetectorBaseAPI as FastAPI
11-
from detector import Detector
12-
from common.scheme import (
5+
from starlette.concurrency import run_in_threadpool
6+
from detectors.common.app import DetectorBaseAPI as FastAPI
7+
from detectors.huggingface.detector import Detector
8+
from detectors.common.scheme import (
139
ContentAnalysisHttpRequest,
1410
ContentsAnalysisResponse,
1511
Error,
@@ -18,7 +14,9 @@
1814

1915
@asynccontextmanager
2016
async def lifespan(app: FastAPI):
21-
app.set_detector(Detector())
17+
detector = Detector()
18+
app.set_detector(detector, detector.model_name)
19+
detector.add_instruments(app.state.instruments)
2220
yield
2321
# Clean up the ML models and release the resources
2422
detector: Detector = app.get_detector()
@@ -42,10 +40,11 @@ async def lifespan(app: FastAPI):
4240
},
4341
)
4442
async def detector_unary_handler(
45-
request: ContentAnalysisHttpRequest,
46-
detector_id: Annotated[str, Header(example="en_syntax_slate.38m.hap")],
43+
request: ContentAnalysisHttpRequest,
4744
):
48-
detector: Detector = app.get_detector()
49-
if not detector:
45+
detectors: List[Detector] = list(app.get_all_detectors().values())
46+
if not len(detectors) or not detectors[0]:
5047
raise RuntimeError("Detector is not initialized")
51-
return ContentsAnalysisResponse(root=detector.run(request))
48+
result = await run_in_threadpool(detectors[0].run, request)
49+
return ContentsAnalysisResponse(root=result)
50+

detectors/huggingface/deploy/servingruntime.yaml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ metadata:
99
opendatahub.io/dashboard: 'true'
1010
spec:
1111
annotations:
12-
prometheus.io/port: '8080'
12+
prometheus.io/port: '8000'
1313
prometheus.io/path: '/metrics'
1414
multiModel: false
1515
supportedModelFormats:
@@ -35,6 +35,10 @@ spec:
3535
value: /mnt/models
3636
- name: HF_HOME
3737
value: /tmp/hf_home
38+
- name: DETECTOR_NAME
39+
valueFrom:
40+
fieldRef:
41+
fieldPath: metadata.name
3842
ports:
3943
- containerPort: 8000
4044
protocol: TCP

detectors/huggingface/detector.py

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
2-
import sys
32

4-
sys.path.insert(0, os.path.abspath(".."))
3+
from detectors.common.instrumented_detector import InstrumentedDetector
4+
55
import json
66
import math
77
import torch
@@ -11,14 +11,15 @@
1111
AutoModelForSequenceClassification,
1212
AutoModelForCausalLM,
1313
)
14-
from common.app import logger
15-
from common.scheme import (
14+
from detectors.common.app import logger
15+
from detectors.common.scheme import (
1616
ContentAnalysisHttpRequest,
1717
ContentAnalysisResponse,
1818
ContentsAnalysisResponse,
1919
)
2020
import gc
2121

22+
2223
def _parse_safe_labels_env():
2324
if os.environ.get("SAFE_LABELS"):
2425
try:
@@ -35,7 +36,8 @@ def _parse_safe_labels_env():
3536
logger.info("SAFE_LABELS env var not set: defaulting to [0].")
3637
return [0]
3738

38-
class Detector:
39+
40+
class Detector(InstrumentedDetector):
3941
risk_names = [
4042
"harm",
4143
"social_bias",
@@ -50,6 +52,7 @@ def __init__(self):
5052
"""
5153
Initialize the Detector class by setting up the model, tokenizer, and device.
5254
"""
55+
super().__init__()
5356
self.tokenizer = None
5457
self.model = None
5558
self.cuda_device = None
@@ -102,6 +105,18 @@ def initialize_model(self, model_files_path):
102105
else:
103106
self.model_name = "unknown"
104107

108+
self.registry_name = self.model_name
109+
110+
# set by k8s to be the pod name
111+
if os.environ.get("DETECTOR_NAME"):
112+
pod_name = os.environ.get("DETECTOR_NAME")
113+
if "-predictor" in pod_name:
114+
# recover the original ISVC name as specified by the user
115+
pod_name = pod_name.split("-predictor")[0]
116+
self.function_name = pod_name
117+
else:
118+
self.function_name = os.path.basename(model_files_path)
119+
105120
logger.info(f"Model type detected: {self.model_name}")
106121

107122
def initialize_device(self):
@@ -173,7 +188,7 @@ def get_probabilities(self, logprobs, safe_token, unsafe_token):
173188
unsafe_token_prob = 1e-50
174189
for gen_token_i in logprobs:
175190
for logprob, index in zip(
176-
gen_token_i.values.tolist()[0], gen_token_i.indices.tolist()[0]
191+
gen_token_i.values.tolist()[0], gen_token_i.indices.tolist()[0]
177192
):
178193
decoded_token = self.tokenizer.convert_ids_to_tokens(index)
179194
if decoded_token.strip().lower() == safe_token.lower():
@@ -226,7 +241,7 @@ def process_sequence_classification(self, text, detector_params=None, threshold=
226241
threshold = detector_params.get("threshold", 0.5)
227242
# Merge safe_labels from env and request
228243
request_safe_labels = set(detector_params.get("safe_labels", []))
229-
all_safe_labels = set(self.safe_labels) | request_safe_labels
244+
all_safe_labels = set(self.safe_labels) | request_safe_labels
230245
content_analyses = []
231246
tokenized = self.tokenizer(
232247
text,
@@ -245,9 +260,9 @@ def process_sequence_classification(self, text, detector_params=None, threshold=
245260
label = self.model.config.id2label[idx]
246261
# Exclude by index or label name
247262
if (
248-
prob >= threshold
249-
and idx not in all_safe_labels
250-
and label not in all_safe_labels
263+
prob >= threshold
264+
and idx not in all_safe_labels
265+
and label not in all_safe_labels
251266
):
252267
detection_value = getattr(self.model.config, "problem_type", None)
253268
content_analyses.append(
@@ -274,16 +289,19 @@ def run(self, input: ContentAnalysisHttpRequest) -> ContentsAnalysisResponse:
274289
ContentsAnalysisResponse: The aggregated response for all input texts.
275290
"""
276291
contents_analyses = []
277-
for text in input.contents:
278-
if self.is_causal_lm:
279-
analyses = self.process_causal_lm(text)
280-
elif self.is_sequence_classifier:
281-
analyses = self.process_sequence_classification(
282-
text, detector_params=getattr(input, "detector_params", None)
283-
)
284-
else:
285-
raise ValueError("Unsupported model type for analysis.")
286-
contents_analyses.append(analyses)
292+
with self.instrument_runtime(self.function_name):
293+
for text in input.contents:
294+
if self.is_causal_lm:
295+
analyses = self.process_causal_lm(text)
296+
elif self.is_sequence_classifier:
297+
analyses = self.process_sequence_classification(
298+
text, detector_params=getattr(input, "detector_params", None)
299+
)
300+
else:
301+
raise ValueError("Unsupported model type for analysis.")
302+
contents_analyses.append(analyses)
303+
is_detection = any(len(analyses) > 0 for analyses in contents_analyses)
304+
self.increment_detector_instruments(self.function_name, is_detection=is_detection)
287305
return contents_analyses
288306

289307
def close(self) -> None:
@@ -299,4 +317,4 @@ def close(self) -> None:
299317
gc.collect()
300318

301319
if torch.cuda.is_available():
302-
torch.cuda.empty_cache()
320+
torch.cuda.empty_cache()

0 commit comments

Comments
 (0)