@@ -3,12 +3,13 @@
 Provides common functionality to eliminate code duplication across categories
 """
 
+import time
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from typing import Any
 
 from src.utils.evaluator import EvaluationResult, ResponseEvaluator, VulnerabilityCategory
-from src.utils.live_display import get_display
+from src.utils.live_display import TestProgress, get_display
 from src.utils.model_client import ModelResponse, OllamaClient
 
 
@@ -56,34 +57,57 @@ def get_vulnerability_category(self) -> VulnerabilityCategory:
         pass
 
     def run_single_test(
-        self, test: BaseTest, current_num: int = 1, total_tests: int = 1
+        self,
+        test: BaseTest,
+        current_num: int = 1,
+        total_tests: int = 1,
+        current_rep: int = 1,
+        total_reps: int = 1,
+        show_prompts: bool = True,
     ) -> tuple[BaseTest, list[ModelResponse], EvaluationResult]:
         """Run a single test with standardized error handling and display"""
         display = get_display()
 
-        # Start test display
-        progress = display.start_test(
-            test.name, test.test_id, test.category, current_num, total_tests
-        )
+        # Show test header only on first repetition
+        progress = None
+        if current_rep == 1:
+            progress = display.start_test(
+                test.name, test.test_id, test.category, current_num, total_tests
+            )
+        else:
+            # Create minimal progress object for repetitions
+            progress = TestProgress(
+                test_name=test.name,
+                test_id=test.test_id,
+                category=test.category,
+                start_time=time.time(),
+                current_test_num=current_num,
+                total_tests=total_tests,
+            )
+
+            # Show repetition header for multi-repetition runs
+            display.show_repetition_header(current_rep, total_reps)
 
         responses = []
 
         try:
             if test.follow_up_prompts and len(test.follow_up_prompts) > 0:
                 # Multi-turn conversation
-                display.show_prompt(progress, test.prompt, test.system_prompt)
-                display.start_thinking_timer(progress)
+                if show_prompts and current_rep == 1:
+                    # Show all prompts at once for multi-turn on first repetition
+                    all_prompts = [test.prompt] + test.follow_up_prompts
+                    display.show_multi_turn_prompts(all_prompts, test.system_prompt)
 
+                # Execute conversation turns
+                display.start_thinking_timer(progress)
                 response = self.client.generate(test.prompt, test.system_prompt)
                 display.stop_thinking_timer()
                 responses.append(response)
                 display.show_response(progress, response)
 
                 if not response.error:
                     for follow_up in test.follow_up_prompts:
-                        display.show_prompt(progress, follow_up, "")
                         display.start_thinking_timer(progress)
-
                         response = self.client.generate(follow_up, "")
                         display.stop_thinking_timer()
                         responses.append(response)
@@ -93,9 +117,12 @@ def run_single_test(
                             break
             else:
                 # Single-turn test
-                display.show_prompt(progress, test.prompt, test.system_prompt)
-                display.start_thinking_timer(progress)
+                if show_prompts and current_rep == 1:
+                    display.show_prompt(
+                        progress, test.prompt, test.system_prompt, show_display=True
+                    )
 
+                display.start_thinking_timer(progress)
                 response = self.client.generate(test.prompt, test.system_prompt)
                 display.stop_thinking_timer()
                 responses.append(response)
@@ -116,7 +143,10 @@ def run_single_test(
 
             # Show evaluation results
             display.show_evaluation(progress, evaluation)
-            display.complete_test(progress, evaluation)
+
+            # Only show completion message on last repetition
+            if current_rep == total_reps:
+                display.complete_test(progress, evaluation)
 
         except Exception as e:
             # Ensure timer is stopped on any exception
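
The hunks above make `run_single_test` repetition-aware: the test header and prompt echo appear only when `current_rep == 1`, and `complete_test` fires only when `current_rep == total_reps`, so intermediate repetitions print just the repetition header, the response, and the evaluation. A minimal sketch of driving one repetition directly; `ExampleTester` and `get_test_cases` are hypothetical stand-ins, while `run_single_test` and its keyword arguments come from this diff:

# Sketch only: ExampleTester is a hypothetical BaseTester subclass and
# get_test_cases a hypothetical accessor; the keyword arguments are real.
tester = ExampleTester(client)
test = tester.get_test_cases()[0]
# Repetition 2 of 3: no test header, no prompt echo, no completion banner.
test_obj, responses, evaluation = tester.run_single_test(
    test, current_num=1, total_tests=1, current_rep=2, total_reps=3
)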
@@ -187,8 +217,66 @@ def _evaluate_single_response(
         # Default evaluation - can be overridden by subclasses
         return self.evaluator.evaluate_deception(test.prompt, response.content)
 
+    def run_test_with_repeats(
+        self, test: BaseTest, repeat_count: int, current_num: int = 1, total_tests: int = 1
+    ) -> list[tuple[BaseTest, list[ModelResponse], EvaluationResult]]:
+        """Run a single test multiple times consecutively"""
+        display = get_display()
+        results = []
+        repetition_summaries = []
+
+        for rep in range(1, repeat_count + 1):
+            try:
+                result = self.run_single_test(
+                    test,
+                    current_num=current_num,
+                    total_tests=total_tests,
+                    current_rep=rep,
+                    total_reps=repeat_count,
+                    show_prompts=True,
+                )
+                test_obj, responses, evaluation = result
+                results.append(result)
+
+                # Collect summary data for this repetition
+                rep_summary = {
+                    "is_vulnerable": evaluation.is_vulnerable,
+                    "confidence": evaluation.confidence,
+                    "response_time": sum(r.response_time for r in responses) if responses else 0,
+                }
+                repetition_summaries.append(rep_summary)
+
+            except Exception as e:
+                # Ensure any running timer is stopped
+                display.stop_thinking_timer()
+                display.error(f"Failed: {test.name} repetition {rep} - {e}")
+                # Create error result
+                error_response = ModelResponse(content="", model=self.client.model, error=str(e))
+                error_evaluation = EvaluationResult(
+                    is_vulnerable=False,
+                    category=self.get_vulnerability_category(),
+                    confidence=0.0,
+                    evidence=[],
+                    reasoning=f"Test failed: {e}",
+                )
+                results.append((test, [error_response], error_evaluation))
+
+                # Add error to summary
+                rep_summary = {
+                    "is_vulnerable": False,
+                    "confidence": 0.0,
+                    "response_time": 0,
+                }
+                repetition_summaries.append(rep_summary)
+
+        # Show vulnerability summary after all repetitions (if more than 1)
+        if repeat_count > 1:
+            display.show_vulnerability_summary(test.test_id, repetition_summaries)
+
+        return results
+
     def run_category_tests(
-        self, category: str | None = None, test_id: str | None = None
+        self, category: str | None = None, test_id: str | None = None, repeat_count: int = 1
     ) -> list[tuple[BaseTest, list[ModelResponse], EvaluationResult]]:
         """Run all tests in category with filtering support"""
         display = get_display()
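
For orientation, a sketch of calling the new `run_test_with_repeats` and inspecting its output; it returns one `(test, responses, evaluation)` tuple per repetition, with failed repetitions converted into error tuples rather than aborting the remaining runs. Only `ExampleTester` (bound to `tester` earlier) is a hypothetical name here:

# Three consecutive runs of the same test case.
results = tester.run_test_with_repeats(test, repeat_count=3)
assert len(results) == 3  # one tuple per repetition, errors included
for _, responses, evaluation in results:
    print(evaluation.is_vulnerable, evaluation.confidence, len(responses))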
@@ -209,23 +297,40 @@ def run_category_tests(
         display.start_category(self.get_category_name(), len(test_cases))
 
         for i, test in enumerate(test_cases, 1):
-            try:
-                result = self.run_single_test(test, current_num=i, total_tests=len(test_cases))
-                results.append(result)
-            except Exception as e:
-                # Ensure any running timer is stopped
-                display.stop_thinking_timer()
-                display.error(f"Failed: {test.name} - {e}")
-                # Create error result
-                error_response = ModelResponse(content="", model=self.client.model, error=str(e))
-                error_evaluation = EvaluationResult(
-                    is_vulnerable=False,
-                    category=self.get_vulnerability_category(),
-                    confidence=0.0,
-                    evidence=[],
-                    reasoning=f"Test failed: {e}",
+            if repeat_count > 1:
+                # Run test with repetitions
+                test_results = self.run_test_with_repeats(
+                    test, repeat_count, current_num=i, total_tests=len(test_cases)
                 )
-                results.append((test, [error_response], error_evaluation))
+                results.extend(test_results)
+            else:
+                # Single run (normal behavior)
+                try:
+                    result = self.run_single_test(
+                        test,
+                        current_num=i,
+                        total_tests=len(test_cases),
+                        current_rep=1,
+                        total_reps=1,
+                        show_prompts=True,
+                    )
+                    results.append(result)
+                except Exception as e:
+                    # Ensure any running timer is stopped
+                    display.stop_thinking_timer()
+                    display.error(f"Failed: {test.name} - {e}")
+                    # Create error result
+                    error_response = ModelResponse(
+                        content="", model=self.client.model, error=str(e)
+                    )
+                    error_evaluation = EvaluationResult(
+                        is_vulnerable=False,
+                        category=self.get_vulnerability_category(),
+                        confidence=0.0,
+                        evidence=[],
+                        reasoning=f"Test failed: {e}",
+                    )
+                    results.append((test, [error_response], error_evaluation))
 
         return results
 
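
One bookkeeping consequence of the branch above: with `repeat_count > 1` the per-test result lists are flattened into `results` via `extend`, so callers receive `len(test_cases) * repeat_count` tuples rather than one per test. A hedged illustration (the category name and test-case count are assumptions):

# Suppose the filtered category yields 4 test cases.
results = tester.run_category_tests(category="deception", repeat_count=3)
assert len(results) == 4 * 3  # one tuple per (test, repetition) pair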
@@ -298,12 +403,13 @@ def run_category_tests_generic(
     client: OllamaClient,
     category: str | None = None,
     test_id: str | None = None,
+    repeat_count: int = 1,
 ) -> dict[str, Any]:
     """Generic runner function for any tester class"""
     display = get_display()
     tester = tester_class(client)
 
-    results = tester.run_category_tests(category, test_id)
+    results = tester.run_category_tests(category, test_id, repeat_count)
     analysis = tester.analyze_results(results)
 
     # Display final summary
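
End to end, the new `repeat_count` parameter threads from `run_category_tests_generic` through `run_category_tests` into `run_test_with_repeats`. A usage sketch, assuming a hypothetical concrete tester class and model tag; the `OllamaClient` constructor signature is also an assumption, and only the runner's own signature is from this diff:

# Hypothetical tester class and model tag; repeat_count is the new knob.
client = OllamaClient(model="example-model")
summary = run_category_tests_generic(ExampleTester, client, repeat_count=5)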