@@ -46,7 +46,7 @@ impl PromptInjectionScanner {
4646 fn create_classifier_from_config ( ) -> Result < ClassificationClient > {
4747 let config = Config :: global ( ) ;
4848
49- let model_name = config
49+ let mut model_name = config
5050 . get_param :: < String > ( "SECURITY_PROMPT_CLASSIFIER_MODEL" )
5151 . ok ( )
5252 . filter ( |s| !s. trim ( ) . is_empty ( ) ) ;
@@ -59,6 +59,23 @@ impl PromptInjectionScanner {
5959 . ok ( )
6060 . filter ( |s| !s. trim ( ) . is_empty ( ) ) ;
6161
62+ if model_name. is_none ( ) {
63+ if let Ok ( mapping_json) = std:: env:: var ( "SECURITY_ML_MODEL_MAPPING" ) {
64+ if let Ok ( mapping) = serde_json:: from_str :: <
65+ crate :: security:: classification_client:: ModelMappingConfig ,
66+ > ( & mapping_json)
67+ {
68+ if let Some ( first_model) = mapping. models . keys ( ) . next ( ) {
69+ tracing:: info!(
70+ default_model = %first_model,
71+ "SECURITY_ML_MODEL_MAPPING available but no model selected - using first available model as default"
72+ ) ;
73+ model_name = Some ( first_model. clone ( ) ) ;
74+ }
75+ }
76+ }
77+ }
78+
6279 tracing:: debug!(
6380 model_name = ?model_name,
6481 has_endpoint = endpoint. is_some( ) ,
@@ -106,20 +123,23 @@ impl PromptInjectionScanner {
106123 self . scan_conversation( messages)
107124 ) ;
108125
109- let highest_confidence_result =
110- self . select_highest_confidence_result ( tool_result? , context_result?) ;
126+ let tool_result = tool_result? ;
127+ let context_result = context_result?;
111128 let threshold = self . get_threshold_from_config ( ) ;
112129
130+ let final_result =
131+ self . select_result_with_context_awareness ( tool_result, context_result, threshold) ;
132+
113133 tracing:: info!(
114- "✅ Security analysis complete: confidence={:.3}, malicious={}" ,
115- highest_confidence_result . confidence,
116- highest_confidence_result . confidence >= threshold
134+ "Security analysis complete: confidence={:.3}, malicious={}" ,
135+ final_result . confidence,
136+ final_result . confidence >= threshold
117137 ) ;
118138
119139 Ok ( ScanResult {
120- is_malicious : highest_confidence_result . confidence >= threshold,
121- confidence : highest_confidence_result . confidence ,
122- explanation : self . build_explanation ( & highest_confidence_result , threshold) ,
140+ is_malicious : final_result . confidence >= threshold,
141+ confidence : final_result . confidence ,
142+ explanation : self . build_explanation ( & final_result , threshold) ,
123143 } )
124144 }
125145
@@ -169,12 +189,29 @@ impl PromptInjectionScanner {
169189 } )
170190 }
171191
172- fn select_highest_confidence_result (
192+ fn select_result_with_context_awareness (
173193 & self ,
174194 tool_result : DetailedScanResult ,
175195 context_result : DetailedScanResult ,
196+ threshold : f32 ,
176197 ) -> DetailedScanResult {
177- if tool_result. confidence >= context_result. confidence {
198+ let context_is_safe = context_result
199+ . ml_confidence
200+ . is_some_and ( |conf| conf < threshold) ;
201+
202+ let tool_has_only_non_critical = !tool_result. pattern_matches . is_empty ( )
203+ && tool_result
204+ . pattern_matches
205+ . iter ( )
206+ . all ( |m| m. threat . risk_level != crate :: security:: patterns:: RiskLevel :: Critical ) ;
207+
208+ if context_is_safe && tool_has_only_non_critical {
209+ DetailedScanResult {
210+ confidence : 0.0 ,
211+ pattern_matches : Vec :: new ( ) ,
212+ ml_confidence : context_result. ml_confidence ,
213+ }
214+ } else if tool_result. confidence >= context_result. confidence {
178215 tool_result
179216 } else {
180217 context_result
0 commit comments