Commit 8a9b090

Fix: Small update in how ML-based prompt injection scanning determines the final result (#6439)
1 parent 8880192 commit 8a9b090

1 file changed

Lines changed: 48 additions & 11 deletions

File tree

crates/goose/src/security/scanner.rs

@@ -46,7 +46,7 @@ impl PromptInjectionScanner {
     fn create_classifier_from_config() -> Result<ClassificationClient> {
         let config = Config::global();
 
-        let model_name = config
+        let mut model_name = config
             .get_param::<String>("SECURITY_PROMPT_CLASSIFIER_MODEL")
             .ok()
             .filter(|s| !s.trim().is_empty());
@@ -59,6 +59,23 @@ impl PromptInjectionScanner {
             .ok()
             .filter(|s| !s.trim().is_empty());
 
+        if model_name.is_none() {
+            if let Ok(mapping_json) = std::env::var("SECURITY_ML_MODEL_MAPPING") {
+                if let Ok(mapping) = serde_json::from_str::<
+                    crate::security::classification_client::ModelMappingConfig,
+                >(&mapping_json)
+                {
+                    if let Some(first_model) = mapping.models.keys().next() {
+                        tracing::info!(
+                            default_model = %first_model,
+                            "SECURITY_ML_MODEL_MAPPING available but no model selected - using first available model as default"
+                        );
+                        model_name = Some(first_model.clone());
+                    }
+                }
+            }
+        }
+
         tracing::debug!(
             model_name = ?model_name,
             has_endpoint = endpoint.is_some(),
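For context, the fallback above in isolation: if no SECURITY_PROMPT_CLASSIFIER_MODEL is configured, the first model named in the SECURITY_ML_MODEL_MAPPING JSON becomes the default. This is a minimal sketch, not the crate's code: the real ModelMappingConfig lives in crate::security::classification_client and is not part of this diff, so the struct shape and the sample JSON below are assumptions.

// Minimal, runnable sketch of the fallback, assuming ModelMappingConfig is
// shaped as { "models": { "<model name>": ... } }. The real struct is defined
// in crate::security::classification_client and is not shown in this diff.
use std::collections::BTreeMap;

#[derive(serde::Deserialize)]
struct ModelMappingConfig {
    // BTreeMap gives a deterministic "first" key; with an unordered map the
    // model picked by keys().next() could vary between runs.
    models: BTreeMap<String, serde_json::Value>,
}

// Returns the configured model if present, otherwise the first model named in
// the mapping JSON (which the commit reads from SECURITY_ML_MODEL_MAPPING).
fn default_model(configured: Option<String>, mapping_json: Option<&str>) -> Option<String> {
    if configured.is_some() {
        return configured;
    }
    // Parse failures leave the name unset, matching the commit's behavior.
    let mapping: ModelMappingConfig = serde_json::from_str(mapping_json?).ok()?;
    mapping.models.keys().next().cloned()
}

fn main() {
    // Hypothetical mapping value, for demonstration only.
    let json = r#"{"models":{"injection-classifier-v1":{"endpoint":"http://localhost:8080"}}}"#;
    assert_eq!(
        default_model(None, Some(json)).as_deref(),
        Some("injection-classifier-v1")
    );
    // An explicitly configured model always wins over the mapping.
    assert_eq!(
        default_model(Some("custom".into()), Some(json)).as_deref(),
        Some("custom")
    );
}

One design note: keys().next() only yields a stable default if the models map preserves key order (a BTreeMap sorts keys, as sketched here); with an unordered map the chosen default can differ between runs.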
@@ -106,20 +123,23 @@ impl PromptInjectionScanner {
             self.scan_conversation(messages)
         );
 
-        let highest_confidence_result =
-            self.select_highest_confidence_result(tool_result?, context_result?);
+        let tool_result = tool_result?;
+        let context_result = context_result?;
         let threshold = self.get_threshold_from_config();
 
+        let final_result =
+            self.select_result_with_context_awareness(tool_result, context_result, threshold);
+
         tracing::info!(
-            "Security analysis complete: confidence={:.3}, malicious={}",
-            highest_confidence_result.confidence,
-            highest_confidence_result.confidence >= threshold
+            "Security analysis complete: confidence={:.3}, malicious={}",
+            final_result.confidence,
+            final_result.confidence >= threshold
         );
 
         Ok(ScanResult {
-            is_malicious: highest_confidence_result.confidence >= threshold,
-            confidence: highest_confidence_result.confidence,
-            explanation: self.build_explanation(&highest_confidence_result, threshold),
+            is_malicious: final_result.confidence >= threshold,
+            confidence: final_result.confidence,
+            explanation: self.build_explanation(&final_result, threshold),
         })
     }
 
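The final verdict is still a plain threshold comparison on whichever result the selector returns, as in this tiny runnable sketch (the 0.5 threshold is assumed here; the commit reads the real value via get_threshold_from_config):

fn main() {
    let threshold: f32 = 0.5; // assumed; the commit reads it from config

    for confidence in [0.12_f32, 0.87] {
        // Mirrors the tracing::info! line in the hunk above.
        println!(
            "Security analysis complete: confidence={:.3}, malicious={}",
            confidence,
            confidence >= threshold
        );
    }
}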
@@ -169,12 +189,29 @@ impl PromptInjectionScanner {
         })
     }
 
-    fn select_highest_confidence_result(
+    fn select_result_with_context_awareness(
         &self,
         tool_result: DetailedScanResult,
         context_result: DetailedScanResult,
+        threshold: f32,
     ) -> DetailedScanResult {
-        if tool_result.confidence >= context_result.confidence {
+        let context_is_safe = context_result
+            .ml_confidence
+            .is_some_and(|conf| conf < threshold);
+
+        let tool_has_only_non_critical = !tool_result.pattern_matches.is_empty()
+            && tool_result
+                .pattern_matches
+                .iter()
+                .all(|m| m.threat.risk_level != crate::security::patterns::RiskLevel::Critical);
+
+        if context_is_safe && tool_has_only_non_critical {
+            DetailedScanResult {
+                confidence: 0.0,
+                pattern_matches: Vec::new(),
+                ml_confidence: context_result.ml_confidence,
+            }
+        } else if tool_result.confidence >= context_result.confidence {
             tool_result
         } else {
             context_result
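To see what changed in behavior, here is a minimal, runnable model of the new selection rule. The types are simplified stand-ins: DetailedScanResult field names come from the diff, while PatternMatch and RiskLevel are reduced to the pieces used here (the real pattern match nests the risk level under a threat field), and the 0.5 threshold is illustrative.

// Self-contained model of select_result_with_context_awareness; types are
// simplified stand-ins for the crate's real ones.

#[derive(Clone, Copy, PartialEq)]
enum RiskLevel {
    Low,
    Critical,
}

struct PatternMatch {
    risk_level: RiskLevel, // the crate accesses this as m.threat.risk_level
}

struct DetailedScanResult {
    confidence: f32,
    pattern_matches: Vec<PatternMatch>,
    ml_confidence: Option<f32>,
}

fn select(tool: DetailedScanResult, context: DetailedScanResult, threshold: f32) -> DetailedScanResult {
    // "Safe" requires the conversation-level ML score to exist AND to sit
    // below the threshold; a missing score never counts as safe.
    let context_is_safe = context.ml_confidence.is_some_and(|c| c < threshold);

    // The tool scan matched at least one pattern, but none of them Critical.
    let tool_has_only_non_critical = !tool.pattern_matches.is_empty()
        && tool
            .pattern_matches
            .iter()
            .all(|m| m.risk_level != RiskLevel::Critical);

    if context_is_safe && tool_has_only_non_critical {
        // Override: trust the safe conversation verdict and suppress the
        // non-critical pattern hits entirely.
        DetailedScanResult {
            confidence: 0.0,
            pattern_matches: Vec::new(),
            ml_confidence: context.ml_confidence,
        }
    } else if tool.confidence >= context.confidence {
        tool
    } else {
        context
    }
}

fn main() {
    // A non-critical pattern hit with high confidence, against a conversation
    // the ML model scored well below a 0.5 threshold.
    let tool = DetailedScanResult {
        confidence: 0.7,
        pattern_matches: vec![PatternMatch { risk_level: RiskLevel::Low }],
        ml_confidence: None,
    };
    let context = DetailedScanResult {
        confidence: 0.2,
        pattern_matches: Vec::new(),
        ml_confidence: Some(0.2),
    };

    let result = select(tool, context, 0.5);
    // Under the old highest-confidence rule the 0.7 tool result would win and
    // the scan would be flagged; the context-aware rule zeroes it out.
    assert_eq!(result.confidence, 0.0);
    assert!(result.pattern_matches.is_empty());
}

The practical effect: non-critical pattern hits in tool output no longer escalate a scan that the conversation-level ML classifier considers safe, while critical patterns, or a missing ML score, still fall through to the old highest-confidence comparison.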
