55from unittest .mock import Mock , patch
66
77# relative imports
8- from detectors .huggingface .detector import Detector , ContentAnalysisResponse
9- from scheme import ContentAnalysisHttpRequest
8+ from detectors .huggingface .detector import Detector
9+ from detectors . common . scheme import ContentAnalysisResponse , ContentAnalysisHttpRequest
1010
1111
1212@pytest .fixture
@@ -60,58 +60,63 @@ def detector_causal_lm(self):
6060 detector .is_causal_lm = True
6161 detector .is_sequence_classifier = False
6262 detector .risk_names = ["harm" , "bias" ]
63+ detector .function_name = "test_causal_lm"
64+ detector .instruments = {} # Initialize empty instruments dict
6365
6466 return detector
6567
6668 def test_run_sequence_classifier_single_short_input (self , detector_sequence ):
67- request = ContentAnalysisHttpRequest (contents = ["Test content" ])
69+ request = ContentAnalysisHttpRequest (contents = ["Test content" ], detector_params = None )
6870 results = detector_sequence .run (request )
6971
7072 assert len (results ) == 1
7173 assert isinstance (results [0 ][0 ], ContentAnalysisResponse )
72- assert results [0 ][0 ].detection_type == "sequence_classification"
74+ # detection_type is the label from the model (e.g., "LABEL_1", not "sequence_classification")
75+ assert results [0 ][0 ].detection_type in detector_sequence .model .config .id2label .values ()
7376
7477 def test_run_sequence_classifier_single_long_input (self , detector_sequence ):
7578 request = ContentAnalysisHttpRequest (
7679 contents = [
7780 "This is a long content. " * 1_000 ,
78- ]
81+ ],
82+ detector_params = None
7983 )
8084 results = detector_sequence .run (request )
8185
8286 assert len (results ) == 1
8387 assert isinstance (results [0 ][0 ], ContentAnalysisResponse )
84- assert results [0 ][0 ].detection_type == "sequence_classification"
88+ assert results [0 ][0 ].detection_type in detector_sequence . model . config . id2label . values ()
8589
8690 def test_run_sequence_classifier_empty_input (self , detector_sequence ):
87- request = ContentAnalysisHttpRequest (contents = ["" ])
91+ request = ContentAnalysisHttpRequest (contents = ["" ], detector_params = None )
8892 results = detector_sequence .run (request )
8993
9094 assert len (results ) == 1
9195 assert isinstance (results [0 ][0 ], ContentAnalysisResponse )
92- assert results [0 ][0 ].detection_type == "sequence_classification"
96+ assert results [0 ][0 ].detection_type in detector_sequence . model . config . id2label . values ()
9397
9498 def test_run_sequence_classifier_multiple_contents (self , detector_sequence ):
95- request = ContentAnalysisHttpRequest (contents = ["Content 1" , "Content 2" ])
99+ request = ContentAnalysisHttpRequest (contents = ["Content 1" , "Content 2" ], detector_params = None )
96100 results = detector_sequence .run (request )
97101
98102 assert len (results ) == 2
99103 for content_analysis in results :
100104 assert len (content_analysis ) == 1
101105 assert isinstance (content_analysis [0 ], ContentAnalysisResponse )
102- assert content_analysis [0 ].detection_type == "sequence_classification"
106+ assert content_analysis [0 ].detection_type in detector_sequence . model . config . id2label . values ()
103107
104108 def test_run_unsupported_model (self ):
105109 detector = Detector .__new__ (Detector )
106110 detector .is_causal_lm = False
107111 detector .is_sequence_classifier = False
112+ detector .function_name = "test_detector"
108113
109- request = ContentAnalysisHttpRequest (contents = ["Test content" ])
114+ request = ContentAnalysisHttpRequest (contents = ["Test content" ], detector_params = None )
110115 with pytest .raises (ValueError , match = "Unsupported model type for analysis" ):
111116 detector .run (request )
112117
113118 def test_run_causal_lm_single_short_input (self , detector_causal_lm ):
114- request = ContentAnalysisHttpRequest (contents = ["Test content" ])
119+ request = ContentAnalysisHttpRequest (contents = ["Test content" ], detector_params = None )
115120 results = detector_causal_lm .run (request )
116121
117122 assert len (results ) == 1
@@ -122,7 +127,8 @@ def test_run_causal_lm_single_long_input(self, detector_causal_lm):
122127 request = ContentAnalysisHttpRequest (
123128 contents = [
124129 "This is a long content. " * 1_000 ,
125- ]
130+ ],
131+ detector_params = None
126132 )
127133 results = detector_causal_lm .run (request )
128134
@@ -131,15 +137,15 @@ def test_run_causal_lm_single_long_input(self, detector_causal_lm):
131137 assert results [0 ][0 ].detection_type == "causal_lm"
132138
133139 def test_run_causal_lm_empty_input (self , detector_causal_lm ):
134- request = ContentAnalysisHttpRequest (contents = ["" ])
140+ request = ContentAnalysisHttpRequest (contents = ["" ], detector_params = None )
135141 results = detector_causal_lm .run (request )
136142
137143 assert len (results ) == 1
138144 assert isinstance (results [0 ][0 ], ContentAnalysisResponse )
139145 assert results [0 ][0 ].detection_type == "causal_lm"
140146
141147 def tes_run_causal_lm_multiple_contents (self , detector_causal_lm ):
142- request = ContentAnalysisHttpRequest (contents = ["Content 1" , "Content 2" ])
148+ request = ContentAnalysisHttpRequest (contents = ["Content 1" , "Content 2" ], detector_params = None )
143149 results = detector_causal_lm .run (request )
144150
145151 assert len (results ) == 2
0 commit comments