Skip to content

Commit d0d08fa

Browse files
authored
Merge pull request #65 from m-misiura/fix_tests_hf_runtime
Fixing Tier 1 - Hugging Face Runtime unit tests
2 parents 2190e36 + 81f7219 commit d0d08fa

File tree

7 files changed

+82
-81
lines changed

7 files changed

+82
-81
lines changed

.github/workflows/test-huggingface-runtime.yaml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,16 @@ jobs:
7474
HF_HOME: /tmp/huggingface
7575
TRANSFORMERS_CACHE: /tmp/transformers_cache
7676
TOKENIZERS_PARALLELISM: false
77+
MODEL_DIR: tests/dummy_models/bert/BertForSequenceClassification
7778
run: |
7879
python -c "
7980
try:
80-
from detectors.huggingface.detector import HuggingFaceDetector
81-
print('HuggingFaceDetector import successful')
81+
from detectors.huggingface.detector import Detector
82+
print('Detector import successful')
8283
8384
# Test basic initialization
84-
detector = HuggingFaceDetector()
85-
print('HuggingFaceDetector initialization successful')
85+
detector = Detector()
86+
print('Detector initialization successful')
8687
except Exception as e:
8788
print(f'Error testing HF detector: {e}')
8889
exit(1)

detectors/common/requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ pre-commit==3.8.0
44
pytest==8.3.2
55
tls-test-tools
66
protobuf==6.33.0
7+
torch==2.9.0

tests/detectors/huggingface/test_method_initialize_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44

55
# local imports
6-
from detectors.huggingface.scheme import ContentAnalysisResponse
6+
from detectors.common.scheme import ContentAnalysisResponse
77
from detectors.huggingface.detector import Detector
88

99

tests/detectors/huggingface/test_method_process_causal_lm.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,9 @@ def validate_results(self, results, input_text, detector):
6060
"detection",
6161
"detection_type",
6262
"score",
63-
"sequence_classification",
64-
"sequence_probability",
65-
"token_classifications",
66-
"token_probabilities",
6763
"text",
6864
"evidences",
65+
"metadata",
6966
]
7067

7168
for field in expected_fields:
@@ -79,16 +76,12 @@ def validate_results(self, results, input_text, detector):
7976
assert isinstance(result.detection, str)
8077
assert isinstance(result.detection_type, str)
8178
assert isinstance(result.score, float)
82-
assert isinstance(result.sequence_classification, str)
83-
assert isinstance(result.sequence_probability, float)
8479
assert isinstance(result.text, str)
8580
assert isinstance(result.evidences, list)
8681

8782
assert 0 <= result.start <= len(input_text)
8883
assert 0 <= result.end <= len(input_text)
8984
assert 0.0 <= result.score <= 1.0
90-
assert 0.0 <= result.sequence_probability <= 1.0
91-
assert result.sequence_classification in detector.risk_names
9285

9386
def test_process_causal_lm_single_short_input(self, detector_instance):
9487
text = "This is a test."

tests/detectors/huggingface/test_method_process_sequence_classification.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,9 @@ def validate_results(self, results, input_text):
4040
"detection",
4141
"detection_type",
4242
"score",
43-
"sequence_classification",
44-
"sequence_probability",
45-
"token_classifications",
46-
"token_probabilities",
4743
"text",
4844
"evidences",
45+
"metadata",
4946
]
5047

5148
for field in expected_fields:
@@ -59,12 +56,6 @@ def validate_results(self, results, input_text):
5956
assert isinstance(result.detection, str), "detection should be string"
6057
assert isinstance(result.detection_type, str), "detection_type should be string"
6158
assert isinstance(result.score, float), "score should be float"
62-
assert isinstance(
63-
result.sequence_classification, str
64-
), "sequence_classification should be string"
65-
assert isinstance(
66-
result.sequence_probability, float
67-
), "sequence_probability should be float"
6859
assert isinstance(result.text, str), "text should be string"
6960
assert isinstance(result.evidences, list), "evidences should be list"
7061

@@ -73,9 +64,6 @@ def validate_results(self, results, input_text):
7364
), "start should be within text bounds"
7465
assert 0 <= result.end <= len(input_text), "end should be within text bounds"
7566
assert 0.0 <= result.score <= 1.0, "score should be between 0 and 1"
76-
assert (
77-
0.0 <= result.sequence_probability <= 1.0
78-
), "sequence_probability should be between 0 and 1"
7967

8068
return result
8169

tests/detectors/huggingface/test_method_run.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from unittest.mock import Mock, patch
66

77
# relative imports
8-
from detectors.huggingface.detector import Detector, ContentAnalysisResponse
9-
from scheme import ContentAnalysisHttpRequest
8+
from detectors.huggingface.detector import Detector
9+
from detectors.common.scheme import ContentAnalysisResponse, ContentAnalysisHttpRequest
1010

1111

1212
@pytest.fixture
@@ -60,58 +60,63 @@ def detector_causal_lm(self):
6060
detector.is_causal_lm = True
6161
detector.is_sequence_classifier = False
6262
detector.risk_names = ["harm", "bias"]
63+
detector.function_name = "test_causal_lm"
64+
detector.instruments = {} # Initialize empty instruments dict
6365

6466
return detector
6567

6668
def test_run_sequence_classifier_single_short_input(self, detector_sequence):
67-
request = ContentAnalysisHttpRequest(contents=["Test content"])
69+
request = ContentAnalysisHttpRequest(contents=["Test content"], detector_params=None)
6870
results = detector_sequence.run(request)
6971

7072
assert len(results) == 1
7173
assert isinstance(results[0][0], ContentAnalysisResponse)
72-
assert results[0][0].detection_type == "sequence_classification"
74+
# detection_type is the label from the model (e.g., "LABEL_1", not "sequence_classification")
75+
assert results[0][0].detection_type in detector_sequence.model.config.id2label.values()
7376

7477
def test_run_sequence_classifier_single_long_input(self, detector_sequence):
7578
request = ContentAnalysisHttpRequest(
7679
contents=[
7780
"This is a long content. " * 1_000,
78-
]
81+
],
82+
detector_params=None
7983
)
8084
results = detector_sequence.run(request)
8185

8286
assert len(results) == 1
8387
assert isinstance(results[0][0], ContentAnalysisResponse)
84-
assert results[0][0].detection_type == "sequence_classification"
88+
assert results[0][0].detection_type in detector_sequence.model.config.id2label.values()
8589

8690
def test_run_sequence_classifier_empty_input(self, detector_sequence):
87-
request = ContentAnalysisHttpRequest(contents=[""])
91+
request = ContentAnalysisHttpRequest(contents=[""], detector_params=None)
8892
results = detector_sequence.run(request)
8993

9094
assert len(results) == 1
9195
assert isinstance(results[0][0], ContentAnalysisResponse)
92-
assert results[0][0].detection_type == "sequence_classification"
96+
assert results[0][0].detection_type in detector_sequence.model.config.id2label.values()
9397

9498
def test_run_sequence_classifier_multiple_contents(self, detector_sequence):
95-
request = ContentAnalysisHttpRequest(contents=["Content 1", "Content 2"])
99+
request = ContentAnalysisHttpRequest(contents=["Content 1", "Content 2"], detector_params=None)
96100
results = detector_sequence.run(request)
97101

98102
assert len(results) == 2
99103
for content_analysis in results:
100104
assert len(content_analysis) == 1
101105
assert isinstance(content_analysis[0], ContentAnalysisResponse)
102-
assert content_analysis[0].detection_type == "sequence_classification"
106+
assert content_analysis[0].detection_type in detector_sequence.model.config.id2label.values()
103107

104108
def test_run_unsupported_model(self):
105109
detector = Detector.__new__(Detector)
106110
detector.is_causal_lm = False
107111
detector.is_sequence_classifier = False
112+
detector.function_name = "test_detector"
108113

109-
request = ContentAnalysisHttpRequest(contents=["Test content"])
114+
request = ContentAnalysisHttpRequest(contents=["Test content"], detector_params=None)
110115
with pytest.raises(ValueError, match="Unsupported model type for analysis"):
111116
detector.run(request)
112117

113118
def test_run_causal_lm_single_short_input(self, detector_causal_lm):
114-
request = ContentAnalysisHttpRequest(contents=["Test content"])
119+
request = ContentAnalysisHttpRequest(contents=["Test content"], detector_params=None)
115120
results = detector_causal_lm.run(request)
116121

117122
assert len(results) == 1
@@ -122,7 +127,8 @@ def test_run_causal_lm_single_long_input(self, detector_causal_lm):
122127
request = ContentAnalysisHttpRequest(
123128
contents=[
124129
"This is a long content. " * 1_000,
125-
]
130+
],
131+
detector_params=None
126132
)
127133
results = detector_causal_lm.run(request)
128134

@@ -131,15 +137,15 @@ def test_run_causal_lm_single_long_input(self, detector_causal_lm):
131137
assert results[0][0].detection_type == "causal_lm"
132138

133139
def test_run_causal_lm_empty_input(self, detector_causal_lm):
134-
request = ContentAnalysisHttpRequest(contents=[""])
140+
request = ContentAnalysisHttpRequest(contents=[""], detector_params=None)
135141
results = detector_causal_lm.run(request)
136142

137143
assert len(results) == 1
138144
assert isinstance(results[0][0], ContentAnalysisResponse)
139145
assert results[0][0].detection_type == "causal_lm"
140146

141147
def test_run_causal_lm_multiple_contents(self, detector_causal_lm):
142-
request = ContentAnalysisHttpRequest(contents=["Content 1", "Content 2"])
148+
request = ContentAnalysisHttpRequest(contents=["Content 1", "Content 2"], detector_params=None)
143149
results = detector_causal_lm.run(request)
144150

145151
assert len(results) == 2

tests/detectors/huggingface/test_metrics.py

Lines changed: 52 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pytest
66
import torch
77
from starlette.testclient import TestClient
8+
from prometheus_client import REGISTRY
89

910
# DO NOT IMPORT THIS VALUE, if we import common.app before the test fixtures we can break prometheus multiprocessing
1011
METRIC_PREFIX = "trustyai_guardrails"
@@ -25,8 +26,10 @@ def send_request(client: TestClient, detect: bool, slow: bool = False):
2526

2627

2728
def get_metric_dict(client: TestClient):
28-
metrics = client.get("/metrics")
29-
metrics = metrics.content.decode().split("\n")
29+
# In test mode with TestClient, we're running in a single process,
30+
# so multiprocess mode doesn't work. Use the default REGISTRY directly.
31+
from prometheus_client import generate_latest, REGISTRY
32+
metrics = generate_latest(REGISTRY).decode().split("\n")
3033
metric_dict = {}
3134

3235
for m in metrics:
@@ -36,45 +39,54 @@ def get_metric_dict(client: TestClient):
3639

3740
return metric_dict
3841

42+
@pytest.fixture(scope="session")
43+
def client(prometheus_multiproc_dir):
44+
# Clear any existing metrics from the REGISTRY before importing the app
45+
# This is needed because even in multiprocess mode, metrics are registered to REGISTRY
46+
collectors_to_unregister = [
47+
c for c in list(REGISTRY._collector_to_names.keys())
48+
if hasattr(c, '_name') and 'trustyai_guardrails' in c._name
49+
]
50+
for collector in collectors_to_unregister:
51+
try:
52+
REGISTRY.unregister(collector)
53+
except Exception:
54+
pass
55+
56+
current_dir = os.path.dirname(__file__)
57+
parent_dir = os.path.dirname(os.path.dirname(current_dir))
58+
os.environ["MODEL_DIR"] = os.path.join(parent_dir, "dummy_models", "bert/BertForSequenceClassification")
59+
60+
from detectors.huggingface.app import app
61+
from detectors.huggingface.detector import Detector
62+
detector = Detector()
63+
64+
# patch the model to allow for control over detections - long messages will flag
65+
def detection_fn(*args, **kwargs):
66+
output = Mock()
67+
if kwargs["input_ids"].shape[-1] > 10:
68+
output.logits = torch.tensor([[0.0, 1.0]])
69+
else:
70+
output.logits = torch.tensor([[1.0, 0.0]])
71+
72+
if kwargs["input_ids"].shape[-1] > 100:
73+
time.sleep(.25)
74+
return output
75+
76+
class ModelMock:
77+
def __init__(self):
78+
self.config = Mock()
79+
self.config.id2label = detector.model.config.id2label
80+
self.config.problem_type = detector.model.config.problem_type
81+
def __call__(self, *args, **kwargs):
82+
return detection_fn(*args, **kwargs)
83+
84+
detector.model = ModelMock()
85+
app.set_detector(detector, detector.registry_name)
86+
detector.set_instruments(app.state.instruments)
87+
return TestClient(app)
88+
3989
class TestMetrics:
40-
@pytest.fixture
41-
def client(self):
42-
current_dir = os.path.dirname(__file__)
43-
parent_dir = os.path.dirname(os.path.dirname(current_dir))
44-
os.environ["MODEL_DIR"] = os.path.join(parent_dir, "dummy_models", "bert/BertForSequenceClassification")
45-
46-
from detectors.huggingface.app import app
47-
# clear the metric registry at the start of each test, but AFTER the multiprocessing metrics is set up
48-
import prometheus_client
49-
prometheus_client.REGISTRY._names_to_collectors.clear()
50-
51-
from detectors.huggingface.detector import Detector
52-
detector = Detector()
53-
54-
# patch the model to allow for control over detections - long messages will flag
55-
def detection_fn(*args, **kwargs):
56-
output = Mock()
57-
if kwargs["input_ids"].shape[-1] > 10:
58-
output.logits = torch.tensor([[0.0, 1.0]])
59-
else:
60-
output.logits = torch.tensor([[1.0, 0.0]])
61-
62-
if kwargs["input_ids"].shape[-1] > 100:
63-
time.sleep(.25)
64-
return output
65-
66-
class ModelMock:
67-
def __init__(self):
68-
self.config = Mock()
69-
self.config.id2label = detector.model.config.id2label
70-
self.config.problem_type = detector.model.config.problem_type
71-
def __call__(self, *args, **kwargs):
72-
return detection_fn(*args, **kwargs)
73-
74-
detector.model = ModelMock()
75-
app.set_detector(detector, detector.registry_name)
76-
detector.set_instruments(app.state.instruments)
77-
return TestClient(app)
7890

7991

8092

0 commit comments

Comments (0)