11import os
2- import sys
32
4- sys .path .insert (0 , os .path .abspath (".." ))
3+ from detectors .common .instrumented_detector import InstrumentedDetector
4+
55import json
66import math
77import torch
1111 AutoModelForSequenceClassification ,
1212 AutoModelForCausalLM ,
1313)
14- from common .app import logger
15- from common .scheme import (
14+ from detectors . common .app import logger
15+ from detectors . common .scheme import (
1616 ContentAnalysisHttpRequest ,
1717 ContentAnalysisResponse ,
1818 ContentsAnalysisResponse ,
1919)
2020import gc
2121
22+
2223def _parse_safe_labels_env ():
2324 if os .environ .get ("SAFE_LABELS" ):
2425 try :
@@ -35,7 +36,8 @@ def _parse_safe_labels_env():
3536 logger .info ("SAFE_LABELS env var not set: defaulting to [0]." )
3637 return [0 ]
3738
38- class Detector :
39+
40+ class Detector (InstrumentedDetector ):
3941 risk_names = [
4042 "harm" ,
4143 "social_bias" ,
@@ -50,6 +52,7 @@ def __init__(self):
5052 """
5153 Initialize the Detector class by setting up the model, tokenizer, and device.
5254 """
55+ super ().__init__ ()
5356 self .tokenizer = None
5457 self .model = None
5558 self .cuda_device = None
@@ -102,6 +105,18 @@ def initialize_model(self, model_files_path):
102105 else :
103106 self .model_name = "unknown"
104107
108+ self .registry_name = self .model_name
109+
110+ # set by k8s to be the pod name
111+ if os .environ .get ("DETECTOR_NAME" ):
112+ pod_name = os .environ .get ("DETECTOR_NAME" )
113+ if "-predictor" in pod_name :
114+ # recover the original ISVC name as specified by the user
115+ pod_name = pod_name .split ("-predictor" )[0 ]
116+ self .function_name = pod_name
117+ else :
118+ self .function_name = os .path .basename (model_files_path )
119+
105120 logger .info (f"Model type detected: { self .model_name } " )
106121
107122 def initialize_device (self ):
@@ -173,7 +188,7 @@ def get_probabilities(self, logprobs, safe_token, unsafe_token):
173188 unsafe_token_prob = 1e-50
174189 for gen_token_i in logprobs :
175190 for logprob , index in zip (
176- gen_token_i .values .tolist ()[0 ], gen_token_i .indices .tolist ()[0 ]
191+ gen_token_i .values .tolist ()[0 ], gen_token_i .indices .tolist ()[0 ]
177192 ):
178193 decoded_token = self .tokenizer .convert_ids_to_tokens (index )
179194 if decoded_token .strip ().lower () == safe_token .lower ():
@@ -226,7 +241,7 @@ def process_sequence_classification(self, text, detector_params=None, threshold=
226241 threshold = detector_params .get ("threshold" , 0.5 )
227242 # Merge safe_labels from env and request
228243 request_safe_labels = set (detector_params .get ("safe_labels" , []))
229- all_safe_labels = set (self .safe_labels ) | request_safe_labels
244+ all_safe_labels = set (self .safe_labels ) | request_safe_labels
230245 content_analyses = []
231246 tokenized = self .tokenizer (
232247 text ,
@@ -245,9 +260,9 @@ def process_sequence_classification(self, text, detector_params=None, threshold=
245260 label = self .model .config .id2label [idx ]
246261 # Exclude by index or label name
247262 if (
248- prob >= threshold
249- and idx not in all_safe_labels
250- and label not in all_safe_labels
263+ prob >= threshold
264+ and idx not in all_safe_labels
265+ and label not in all_safe_labels
251266 ):
252267 detection_value = getattr (self .model .config , "problem_type" , None )
253268 content_analyses .append (
@@ -274,16 +289,19 @@ def run(self, input: ContentAnalysisHttpRequest) -> ContentsAnalysisResponse:
274289 ContentsAnalysisResponse: The aggregated response for all input texts.
275290 """
276291 contents_analyses = []
277- for text in input .contents :
278- if self .is_causal_lm :
279- analyses = self .process_causal_lm (text )
280- elif self .is_sequence_classifier :
281- analyses = self .process_sequence_classification (
282- text , detector_params = getattr (input , "detector_params" , None )
283- )
284- else :
285- raise ValueError ("Unsupported model type for analysis." )
286- contents_analyses .append (analyses )
292+ with self .instrument_runtime (self .function_name ):
293+ for text in input .contents :
294+ if self .is_causal_lm :
295+ analyses = self .process_causal_lm (text )
296+ elif self .is_sequence_classifier :
297+ analyses = self .process_sequence_classification (
298+ text , detector_params = getattr (input , "detector_params" , None )
299+ )
300+ else :
301+ raise ValueError ("Unsupported model type for analysis." )
302+ contents_analyses .append (analyses )
303+ is_detection = any (len (analyses ) > 0 for analyses in contents_analyses )
304+ self .increment_detector_instruments (self .function_name , is_detection = is_detection )
287305 return contents_analyses
288306
289307 def close (self ) -> None :
@@ -299,4 +317,4 @@ def close(self) -> None:
299317 gc .collect ()
300318
301319 if torch .cuda .is_available ():
302- torch .cuda .empty_cache ()
320+ torch .cuda .empty_cache ()
0 commit comments