Skip to content

Commit 41b7f2b

Browse files
committed
feat: add feature-level fusion, output toxicity detection, and streaming output moderation (Phase 2)
- R5: Feature-level fusion architecture (DeBERTa embeddings + heuristic features → FC classifier)
- R6: Output toxicity detection via toxic-bert with configurable thresholds
- R7: Streaming output moderation with PII/secret/toxicity checks and early-stopping
- New modules: feature_extraction, fusion_classifier, output_analyzer, toxicity_detector
- Config: fusion_enabled, output_safety.*, streaming_analysis.output_enabled
1 parent 583c5c2 commit 41b7f2b

11 files changed

Lines changed: 2437 additions & 12 deletions

File tree

config.example.yaml

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,31 @@ streaming_analysis:
271271
# Number of completion tokens between each incremental analysis check.
272272
# Lower values detect threats faster but add marginal CPU overhead per chunk.
273273
token_interval: 50
274+
# Enable output-side analysis during SSE streaming (PII, secrets, toxicity
275+
# on response content in real-time). Requires output_safety.enabled = true.
276+
output_enabled: false
277+
# If a critical output safety finding is detected mid-stream, inject a
278+
# warning into the SSE stream and terminate. Use with caution — this will
279+
# cut off the LLM response mid-generation.
280+
early_stop_on_critical: false
281+
282+
# ---------------------------------------------------------------------------
283+
# Output safety — toxicity detection and response content analysis (R6)
284+
# ---------------------------------------------------------------------------
285+
286+
output_safety:
287+
# Enable output safety analysis on LLM responses. When enabled, the proxy
288+
# analyses response content for toxicity, PII leakage, and secret exposure.
289+
enabled: false
290+
# Enable toxicity detection on response content. Uses a BERT-based classifier
291+
# (unitary/toxic-bert) or falls back to keyword-based detection.
292+
toxicity_enabled: false
293+
# Confidence threshold for toxicity detection (0.0–1.0). Categories scoring
294+
# above this threshold are reported as findings.
295+
toxicity_threshold: 0.7
296+
# Block (replace) the entire response if critical toxicity is detected
297+
# (severe_toxic, threat, or score >= 0.9). Use with caution.
298+
block_on_critical: false
274299

275300
# ---------------------------------------------------------------------------
276301
# ML-based security analysis

crates/llmtrace-core/src/lib.rs

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -915,6 +915,9 @@ pub struct ProxyConfig {
915915
/// PII detection and redaction configuration.
916916
#[serde(default)]
917917
pub pii: PiiConfig,
918+
/// Output safety configuration (toxicity detection, output analysis).
919+
#[serde(default)]
920+
pub output_safety: OutputSafetyConfig,
918921
/// Graceful shutdown configuration.
919922
#[serde(default)]
920923
pub shutdown: ShutdownConfig,
@@ -952,6 +955,7 @@ impl Default for ProxyConfig {
952955
anomaly_detection: AnomalyDetectionConfig::default(),
953956
streaming_analysis: StreamingAnalysisConfig::default(),
954957
pii: PiiConfig::default(),
958+
output_safety: OutputSafetyConfig::default(),
955959
shutdown: ShutdownConfig::default(),
956960
}
957961
}
@@ -1365,6 +1369,12 @@ pub struct StreamingAnalysisConfig {
13651369
/// Number of tokens between each incremental analysis check.
13661370
#[serde(default = "default_streaming_token_interval")]
13671371
pub token_interval: u32,
1372+
/// Enable output-side analysis during SSE streaming (PII, secrets, toxicity on response content).
1373+
#[serde(default)]
1374+
pub output_enabled: bool,
1375+
/// If a critical finding is detected mid-stream, inject a warning and stop.
1376+
#[serde(default)]
1377+
pub early_stop_on_critical: bool,
13681378
}
13691379

13701380
fn default_streaming_token_interval() -> u32 {
@@ -1376,6 +1386,58 @@ impl Default for StreamingAnalysisConfig {
13761386
Self {
13771387
enabled: false,
13781388
token_interval: default_streaming_token_interval(),
1389+
output_enabled: false,
1390+
early_stop_on_critical: false,
1391+
}
1392+
}
1393+
}
1394+
1395+
// ---------------------------------------------------------------------------
1396+
// Output safety configuration
1397+
// ---------------------------------------------------------------------------
1398+
1399+
/// Output safety configuration for response content analysis.
1400+
///
1401+
/// When enabled, the proxy analyses LLM response content for toxicity,
1402+
/// PII leakage, and secret exposure. This is a post-processing step that
1403+
/// runs after the upstream response is received.
///
/// Every field has a serde default, so an absent `output_safety` section
/// (or an empty one) deserializes to the same values as `Default::default()`:
/// all switches off and a toxicity threshold of 0.7.
1404+
///
1405+
/// # Example (YAML)
1406+
///
1407+
/// ```yaml
1408+
/// output_safety:
1409+
///   enabled: true
1410+
///   toxicity_enabled: true
1411+
///   toxicity_threshold: 0.7
1412+
///   block_on_critical: false
1413+
/// ```
1414+
#[derive(Debug, Clone, Serialize, Deserialize)]
1415+
pub struct OutputSafetyConfig {
1416+
/// Enable output safety analysis on LLM responses. Defaults to `false`.
1417+
#[serde(default)]
1418+
pub enabled: bool,
1419+
/// Enable toxicity detection on response content. Defaults to `false`.
1420+
#[serde(default)]
1421+
pub toxicity_enabled: bool,
1422+
/// Confidence threshold for toxicity detection (0.0–1.0). Defaults to 0.7.
/// NOTE(review): this struct performs no range check on deserialization —
/// confirm out-of-range values are rejected or clamped elsewhere.
1423+
#[serde(default = "default_toxicity_threshold")]
1424+
pub toxicity_threshold: f32,
1425+
/// Block (replace) the response if critical toxicity is detected. Defaults to `false`.
1426+
#[serde(default)]
1427+
pub block_on_critical: bool,
1428+
}
1429+
1430+
/// Default value for `OutputSafetyConfig::toxicity_threshold`.
///
/// Referenced by the field's `#[serde(default = "...")]` attribute and by the
/// `Default` impl, so both construction paths agree on 0.7.
fn default_toxicity_threshold() -> f32 {
1431+
0.7
1432+
}
1433+
1434+
/// Defaults leave every output-safety switch off; only the toxicity threshold
/// carries a non-trivial value, taken from the same helper serde uses.
impl Default for OutputSafetyConfig {
1435+
fn default() -> Self {
1436+
Self {
1437+
enabled: false,
1438+
toxicity_enabled: false,
1439+
toxicity_threshold: default_toxicity_threshold(),
1440+
block_on_critical: false,
13791441
}
13801442
}
13811443
}
@@ -1503,6 +1565,19 @@ pub struct SecurityAnalysisConfig {
15031565
/// HuggingFace model ID for NER-based PII detection.
15041566
#[serde(default = "default_ner_model")]
15051567
pub ner_model: String,
1568+
/// Enable feature-level fusion classifier (ADR-013).
1569+
///
1570+
/// When `true`, the ensemble concatenates DeBERTa embeddings with heuristic
1571+
/// feature vectors and feeds them through a learned fusion classifier instead
1572+
/// of combining scores after independent classification.
1573+
#[serde(default)]
1574+
pub fusion_enabled: bool,
1575+
/// Optional file path for trained fusion classifier weights.
1576+
///
1577+
/// When `None`, the fusion classifier is initialised with random weights
1578+
/// (suitable for architecture validation; not for production inference).
1579+
#[serde(default)]
1580+
pub fusion_model_path: Option<String>,
15061581
}
15071582

15081583
fn default_ml_model() -> String {
@@ -1540,6 +1615,8 @@ impl Default for SecurityAnalysisConfig {
15401615
ml_download_timeout_seconds: default_ml_download_timeout_seconds(),
15411616
ner_enabled: false,
15421617
ner_model: default_ner_model(),
1618+
fusion_enabled: false,
1619+
fusion_model_path: None,
15431620
}
15441621
}
15451622
}
@@ -2525,6 +2602,8 @@ mod tests {
25252602
ml_download_timeout_seconds: 300,
25262603
ner_enabled: false,
25272604
ner_model: default_ner_model(),
2605+
fusion_enabled: false,
2606+
fusion_model_path: None,
25282607
},
25292608
otel_ingest: OtelIngestConfig::default(),
25302609
auth: AuthConfig::default(),
@@ -2534,6 +2613,7 @@ mod tests {
25342613
pii: PiiConfig {
25352614
action: PiiAction::AlertAndRedact,
25362615
},
2616+
output_safety: OutputSafetyConfig::default(),
25372617
shutdown: ShutdownConfig::default(),
25382618
};
25392619

@@ -2977,6 +3057,8 @@ mod tests {
29773057
assert_eq!(config.security_analysis.ml_download_timeout_seconds, 300);
29783058
assert!(!config.security_analysis.ner_enabled);
29793059
assert_eq!(config.security_analysis.ner_model, "dslim/bert-base-NER");
3060+
assert!(!config.security_analysis.fusion_enabled);
3061+
assert!(config.security_analysis.fusion_model_path.is_none());
29803062
}
29813063

29823064
#[test]
@@ -2993,6 +3075,8 @@ mod tests {
29933075
assert_eq!(config.ml_download_timeout_seconds, 300);
29943076
assert!(!config.ner_enabled);
29953077
assert_eq!(config.ner_model, "dslim/bert-base-NER");
3078+
assert!(!config.fusion_enabled);
3079+
assert!(config.fusion_model_path.is_none());
29963080
}
29973081

29983082
#[test]
@@ -3006,6 +3090,8 @@ mod tests {
30063090
ml_download_timeout_seconds: 600,
30073091
ner_enabled: true,
30083092
ner_model: "dslim/bert-base-NER".to_string(),
3093+
fusion_enabled: false,
3094+
fusion_model_path: None,
30093095
};
30103096
let json = serde_json::to_string(&config).unwrap();
30113097
let deserialized: SecurityAnalysisConfig = serde_json::from_str(&json).unwrap();

crates/llmtrace-proxy/src/proxy.rs

Lines changed: 81 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
use crate::circuit_breaker::CircuitBreaker;
88
use crate::cost::CostEstimator;
99
use crate::provider::{self, ParsedResponse};
10-
use crate::streaming::{StreamingAccumulator, StreamingSecurityMonitor};
10+
use crate::streaming::{StreamingAccumulator, StreamingOutputMonitor, StreamingSecurityMonitor};
1111
use axum::body::Body;
1212
use axum::extract::State;
1313
use axum::http::{HeaderMap, Request, Response, StatusCode};
@@ -433,6 +433,15 @@ pub async fn proxy_handler(
433433
} else {
434434
None
435435
};
436+
// Initialise the streaming output monitor for response-side analysis (R7).
437+
let mut output_monitor = if is_streaming {
438+
StreamingOutputMonitor::new(
439+
&state_bg.config.streaming_analysis,
440+
&state_bg.config.output_safety,
441+
)
442+
} else {
443+
None
444+
};
436445
let mut raw_collected = Vec::new();
437446
let mut ttft_ms: Option<u64> = None;
438447

@@ -466,6 +475,36 @@ pub async fn proxy_handler(
466475
}
467476
}
468477
}
478+
479+
// --- Real-time streaming OUTPUT analysis (R7) ---
480+
if let Some(ref mut out_mon) = output_monitor {
481+
if out_mon.should_analyze(acc.completion_token_count) {
482+
let new_findings = out_mon
483+
.analyze_incremental(&acc.content, acc.completion_token_count);
484+
if !new_findings.is_empty() {
485+
info!(
486+
%trace_id,
487+
count = new_findings.len(),
488+
tokens = acc.completion_token_count,
489+
"Streaming output safety findings detected mid-stream"
490+
);
491+
if let Some(ref engine) = state_bg.alert_engine {
492+
engine.check_and_alert(trace_id, tenant_id, &new_findings);
493+
}
494+
}
495+
}
496+
497+
// Early stop: inject warning and terminate stream
498+
if out_mon.should_early_stop() {
499+
warn!(
500+
%trace_id,
501+
"Critical output safety issue detected — early stopping stream"
502+
);
503+
let warning = StreamingOutputMonitor::early_stop_sse_event();
504+
let _ = body_sender.send(Ok(Bytes::from(warning))).await;
505+
break;
506+
}
507+
}
469508
}
470509
raw_collected.extend_from_slice(&bytes);
471510
if body_sender.send(Ok(bytes)).await.is_err() {
@@ -500,12 +539,33 @@ pub async fn proxy_handler(
500539
}
501540
}
502541

542+
// Run one final streaming OUTPUT analysis flush.
543+
if let (Some(ref acc), Some(ref mut out_mon)) = (&sse_accumulator, &mut output_monitor) {
544+
let final_findings =
545+
out_mon.analyze_incremental(&acc.content, acc.completion_token_count);
546+
if !final_findings.is_empty() {
547+
info!(
548+
%trace_id,
549+
count = final_findings.len(),
550+
"Streaming output safety findings in final flush"
551+
);
552+
if let Some(ref engine) = state_bg.alert_engine {
553+
engine.check_and_alert(trace_id, tenant_id, &final_findings);
554+
}
555+
}
556+
}
557+
503558
// Collect streaming security findings for attachment to the trace span.
504-
let streaming_findings: Vec<SecurityFinding> = streaming_monitor
559+
let mut streaming_findings: Vec<SecurityFinding> = streaming_monitor
505560
.as_mut()
506561
.map(|m| m.take_findings())
507562
.unwrap_or_default();
508563

564+
// Merge in streaming output findings.
565+
if let Some(ref mut out_mon) = output_monitor {
566+
streaming_findings.extend(out_mon.take_findings());
567+
}
568+
509569
// Build the captured interaction with streaming metrics if applicable
510570
let (response_text, prompt_tokens, completion_tokens, total_tokens) =
511571
if let Some(acc) = sse_accumulator {
@@ -744,7 +804,7 @@ async fn run_security_analysis(
744804
parameters: std::collections::HashMap::new(),
745805
};
746806

747-
match state
807+
let mut all_findings = match state
748808
.security
749809
.analyze_interaction(&captured.prompt_text, &captured.response_text, &context)
750810
.await
@@ -775,7 +835,25 @@ async fn run_security_analysis(
775835
error!(trace_id = %captured.trace_id, "Security analysis failed: {}", e);
776836
Vec::new()
777837
}
838+
};
839+
840+
// --- Output safety analysis (R6) ---
841+
if state.config.output_safety.enabled && !captured.response_text.is_empty() {
842+
let output_analyzer =
843+
llmtrace_security::OutputAnalyzer::new_with_fallback(&state.config.output_safety);
844+
let result = output_analyzer.analyze_output(&captured.response_text);
845+
if !result.findings.is_empty() {
846+
info!(
847+
trace_id = %captured.trace_id,
848+
finding_count = result.findings.len(),
849+
has_critical = result.has_critical_toxicity,
850+
"Output safety findings detected"
851+
);
852+
all_findings.extend(result.findings);
853+
}
778854
}
855+
856+
all_findings
779857
}
780858

781859
/// Store a trace event enriched with security findings.

0 commit comments

Comments
 (0)