diff --git a/src/orchestrator/handlers/chat_completions_detection/streaming.rs b/src/orchestrator/handlers/chat_completions_detection/streaming.rs index 60c5ddd1..9b4360ce 100644 --- a/src/orchestrator/handlers/chat_completions_detection/streaming.rs +++ b/src/orchestrator/handlers/chat_completions_detection/streaming.rs @@ -14,7 +14,7 @@ limitations under the License. */ -use std::{collections::HashMap, sync::Arc}; +use std::{collections::HashMap, sync::Arc, time::Duration}; use futures::{StreamExt, future::try_join_all}; use opentelemetry::trace::TraceId; @@ -39,6 +39,14 @@ use crate::{ }, }; +/// Timeout duration when waiting for completion state entries to become available. +/// This handles the race condition where detectors respond faster than the LLM stream inserts completions. +/// +/// The wait re-checks the shared state in a loop, calling `tokio::task::yield_now()` between checks. +/// Yielding lets other tasks run while the loop keeps polling. The entry is expected to appear almost +/// immediately, so this timeout only matters if the generation server is exceptionally slow or experiencing issues. +const COMPLETION_WAIT_TIMEOUT: Duration = Duration::from_secs(30); + pub async fn handle_streaming( ctx: Arc, task: ChatCompletionsDetectionTask, @@ -563,14 +571,33 @@ async fn handle_whole_doc_detection( } /// Builds a response with output detections. 
-fn output_detection_response( +async fn output_detection_response( completion_state: &Arc>, choice_index: u32, chunk: Chunk, detections: Vec, ) -> Result { - // Get chat completions for this choice index - let chat_completions = completion_state.completions.get(&choice_index).unwrap(); + // Wait for entry to exist (yields to other tasks until ready) + let chat_completions = { + let wait_for_entry = async { + loop { + if let Some(entry) = completion_state.completions.get(&choice_index) { + return entry; + } + tokio::task::yield_now().await; + } + }; + + match tokio::time::timeout(COMPLETION_WAIT_TIMEOUT, wait_for_entry).await { + Ok(entry) => entry, + Err(_) => { + return Err(Error::Other(format!( + "completion entry for choice_index {} not ready after {:?} timeout", + choice_index, COMPLETION_WAIT_TIMEOUT + ))); + } + } + }; // Get range of chat completions for this chunk let chat_completions = chat_completions .range(chunk.input_start_index..=chunk.input_end_index) @@ -709,6 +736,7 @@ async fn process_detection_batch_stream( Ok((choice_index, chunk, detections)) => { let input_end_index = chunk.input_end_index; match output_detection_response(&completion_state, choice_index, chunk, detections) + .await { Ok(chat_completion) => { // Send chat completion to response channel @@ -718,8 +746,29 @@ async fn process_detection_batch_stream( return; } // If this is the final chat completion chunk with content, send chat completion chunk with finish reason - let chat_completions = - completion_state.completions.get(&choice_index).unwrap(); + // Wait for entry to exist (yields to other tasks until ready) + let chat_completions = { + let wait_for_entry = async { + loop { + if let Some(entry) = + completion_state.completions.get(&choice_index) + { + return entry; + } + tokio::task::yield_now().await; + } + }; + + match tokio::time::timeout(COMPLETION_WAIT_TIMEOUT, wait_for_entry) + .await + { + Ok(entry) => entry, + Err(_) => { + error!(%trace_id, %choice_index, "completion 
entry not ready after {:?} timeout", COMPLETION_WAIT_TIMEOUT); + return; + } + } + }; if chat_completions.keys().rev().nth(1) == Some(&input_end_index) && let Some((_, chat_completion)) = chat_completions.last_key_value() && chat_completion diff --git a/tests/chat_completions_streaming.rs b/tests/chat_completions_streaming.rs index bfb7d2cb..f0eaa914 100644 --- a/tests/chat_completions_streaming.rs +++ b/tests/chat_completions_streaming.rs @@ -5512,3 +5512,142 @@ async fn detector_internal_server_error() -> Result<(), anyhow::Error> { Ok(()) } + +#[test(tokio::test)] +async fn fast_detector_race_condition() -> Result<(), anyhow::Error> { + // This test verifies the fix for the race condition where detectors respond faster + // than the LLM stream inserts completions into shared state. + + let mut openai_server = MockServer::new_http("openai"); + openai_server.mock(|when, then| { + when.post() + .path(CHAT_COMPLETIONS_ENDPOINT) + .json(json!({ + "stream": true, + "model": "test-0B", + "messages": [ + Message { role: Role::User, content: Some(Content::Text("Hello!".into())), ..Default::default()}, + ] + }) + ); + then.text_stream(sse([ + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + role: Some(Role::Assistant), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + choices: vec![ChatCompletionChunkChoice { + index: 0, + delta: ChatCompletionDelta { + content: Some("Hi!".into()), + ..Default::default() + }, + ..Default::default() + }], + ..Default::default() + }, + ChatCompletionChunk { + id: "chatcmpl-test".into(), + object: "chat.completion.chunk".into(), + created: 1749227854, + model: "test-0B".into(), + 
choices: vec![ChatCompletionChunkChoice { + index: 0, + finish_reason: Some("stop".into()), + ..Default::default() + }], + ..Default::default() + }, + ])); + }); + + let mut sentence_chunker_server = MockServer::new_grpc("sentence_chunker"); + sentence_chunker_server.mock(|when, then| { + when.post() + .path(CHUNKER_STREAMING_ENDPOINT) + .header(CHUNKER_MODEL_ID_HEADER_NAME, "sentence_chunker") + .pb_stream(vec![BidiStreamingChunkerTokenizationTaskRequest { + text_stream: "Hi!".into(), + input_index_stream: 1, + }]); + then.pb_stream(vec![ChunkerTokenizationStreamResult { + results: vec![Token { + start: 0, + end: 3, + text: "Hi!".into(), + }], + token_count: 0, + processed_index: 3, + start_index: 0, + input_start_index: 1, + input_end_index: 1, + }]); + }); + + // Fast detector that responds immediately with no detections + let mut pii_detector_sentence_server = MockServer::new_http("pii_detector_sentence"); + pii_detector_sentence_server.mock(|when, then| { + when.post() + .path(TEXT_CONTENTS_DETECTOR_ENDPOINT) + .header("detector-id", PII_DETECTOR_SENTENCE) + .json(ContentAnalysisRequest { + contents: vec!["Hi!".into()], + detector_params: DetectorParams::default(), + }); + // Detector responds immediately with no detections + then.json(json!([[]])); + }); + + let test_server = TestOrchestratorServer::builder() + .config_path(ORCHESTRATOR_CONFIG_FILE_PATH) + .openai_server(&openai_server) + .chunker_servers([&sentence_chunker_server]) + .detector_servers([&pii_detector_sentence_server]) + .build() + .await?; + + let response = test_server + .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT) + .json(&json!({ + "stream": true, + "model": "test-0B", + "detectors": { + "input": {}, + "output": { + "pii_detector_sentence": {} + } + }, + "messages": [ + Message { role: Role::User, content: Some(Content::Text("Hello!".into())), ..Default::default()}, + ], + })) + .send() + .await?; + assert_eq!(response.status(), StatusCode::OK); + + let sse_stream: SseStream = 
SseStream::new(response.bytes_stream()); + let messages = sse_stream.try_collect::>().await?; + + // The key assertion: request completes successfully without panicking + + assert!(messages.len() >= 2, "should complete without panicking"); + assert_eq!(messages[0].choices[0].delta.role, Some(Role::Assistant)); + assert!(messages.last().unwrap().choices[0].finish_reason.is_some()); + + Ok(()) +}