Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
limitations under the License.

*/
use std::{collections::HashMap, sync::Arc};
use std::{collections::HashMap, sync::Arc, time::Duration};

use futures::{StreamExt, future::try_join_all};
use opentelemetry::trace::TraceId;
Expand All @@ -39,6 +39,14 @@ use crate::{
},
};

/// Timeout duration when waiting for completion state entries to become available.
/// This handles the race condition where detectors respond faster than the LLM stream inserts completions.
///
/// The waiting mechanism uses cooperative yielding via `tokio::task::yield_now()`, not busy waiting.
/// Each yield returns immediately in the normal case (microseconds), so this timeout only matters
/// if the generation server is exceptionally slow or experiencing issues.
const COMPLETION_WAIT_TIMEOUT: Duration = Duration::from_secs(30);

pub async fn handle_streaming(
ctx: Arc<Context>,
task: ChatCompletionsDetectionTask,
Expand Down Expand Up @@ -563,14 +571,33 @@ async fn handle_whole_doc_detection(
}

/// Builds a response with output detections.
fn output_detection_response(
async fn output_detection_response(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to be async? Or can we perform the await operation inside the tokio runtime instead?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question! IIUC, `yield_now()` only works in an async context, so I think it should remain this way; please let me know what you think!

completion_state: &Arc<CompletionState<ChatCompletionChunk>>,
choice_index: u32,
chunk: Chunk,
detections: Vec<Detection>,
) -> Result<ChatCompletionChunk, Error> {
// Get chat completions for this choice index
let chat_completions = completion_state.completions.get(&choice_index).unwrap();
// Wait for entry to exist (yields to other tasks until ready)
let chat_completions = {
let wait_for_entry = async {
loop {
if let Some(entry) = completion_state.completions.get(&choice_index) {
return entry;
}
tokio::task::yield_now().await;
}
};

match tokio::time::timeout(COMPLETION_WAIT_TIMEOUT, wait_for_entry).await {
Ok(entry) => entry,
Err(_) => {
return Err(Error::Other(format!(
"completion entry for choice_index {} not ready after {:?} timeout",
choice_index, COMPLETION_WAIT_TIMEOUT
)));
}
}
};
// Get range of chat completions for this chunk
let chat_completions = chat_completions
.range(chunk.input_start_index..=chunk.input_end_index)
Expand Down Expand Up @@ -709,6 +736,7 @@ async fn process_detection_batch_stream(
Ok((choice_index, chunk, detections)) => {
let input_end_index = chunk.input_end_index;
match output_detection_response(&completion_state, choice_index, chunk, detections)
.await
{
Ok(chat_completion) => {
// Send chat completion to response channel
Expand All @@ -718,8 +746,29 @@ async fn process_detection_batch_stream(
return;
}
// If this is the final chat completion chunk with content, send chat completion chunk with finish reason
let chat_completions =
completion_state.completions.get(&choice_index).unwrap();
// Wait for entry to exist (yields to other tasks until ready)
let chat_completions = {
let wait_for_entry = async {
loop {
if let Some(entry) =
completion_state.completions.get(&choice_index)
{
return entry;
}
tokio::task::yield_now().await;
}
};

match tokio::time::timeout(COMPLETION_WAIT_TIMEOUT, wait_for_entry)
.await
{
Ok(entry) => entry,
Err(_) => {
error!(%trace_id, %choice_index, "completion entry not ready after {:?} timeout", COMPLETION_WAIT_TIMEOUT);
return;
}
}
};
if chat_completions.keys().rev().nth(1) == Some(&input_end_index)
&& let Some((_, chat_completion)) = chat_completions.last_key_value()
&& chat_completion
Expand Down
139 changes: 139 additions & 0 deletions tests/chat_completions_streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5512,3 +5512,142 @@ async fn detector_internal_server_error() -> Result<(), anyhow::Error> {

Ok(())
}

#[test(tokio::test)]
async fn fast_detector_race_condition() -> Result<(), anyhow::Error> {
    // Regression test for the race condition where a detector responds faster
    // than the LLM stream inserts completions into the shared completion state.
    // Before the fix, the orchestrator `.unwrap()`ed a not-yet-populated entry
    // and panicked; with the fix it cooperatively waits for the entry instead.

    // Mock OpenAI server: streams a role chunk, a content chunk, and a final
    // chunk carrying the finish reason.
    let mut openai_server = MockServer::new_http("openai");
    openai_server.mock(|when, then| {
        when.post()
            .path(CHAT_COMPLETIONS_ENDPOINT)
            .json(json!({
                "stream": true,
                "model": "test-0B",
                "messages": [
                    Message { role: Role::User, content: Some(Content::Text("Hello!".into())), ..Default::default()},
                ]
            })
        );
        // All streamed chunks share the same metadata; only `choices` varies.
        let chunk = |choices: Vec<ChatCompletionChunkChoice>| ChatCompletionChunk {
            id: "chatcmpl-test".into(),
            object: "chat.completion.chunk".into(),
            created: 1749227854,
            model: "test-0B".into(),
            choices,
            ..Default::default()
        };
        then.text_stream(sse([
            // 1. Role announcement
            chunk(vec![ChatCompletionChunkChoice {
                index: 0,
                delta: ChatCompletionDelta {
                    role: Some(Role::Assistant),
                    ..Default::default()
                },
                ..Default::default()
            }]),
            // 2. Content delta
            chunk(vec![ChatCompletionChunkChoice {
                index: 0,
                delta: ChatCompletionDelta {
                    content: Some("Hi!".into()),
                    ..Default::default()
                },
                ..Default::default()
            }]),
            // 3. Final chunk with finish reason
            chunk(vec![ChatCompletionChunkChoice {
                index: 0,
                finish_reason: Some("stop".into()),
                ..Default::default()
            }]),
        ]));
    });

    // Mock chunker: tokenizes the streamed content into a single sentence chunk.
    let mut sentence_chunker_server = MockServer::new_grpc("sentence_chunker");
    sentence_chunker_server.mock(|when, then| {
        when.post()
            .path(CHUNKER_STREAMING_ENDPOINT)
            .header(CHUNKER_MODEL_ID_HEADER_NAME, "sentence_chunker")
            .pb_stream(vec![BidiStreamingChunkerTokenizationTaskRequest {
                text_stream: "Hi!".into(),
                input_index_stream: 1,
            }]);
        then.pb_stream(vec![ChunkerTokenizationStreamResult {
            results: vec![Token {
                start: 0,
                end: 3,
                text: "Hi!".into(),
            }],
            token_count: 0,
            processed_index: 3,
            start_index: 0,
            input_start_index: 1,
            input_end_index: 1,
        }]);
    });

    // Fast detector that responds immediately with no detections, so it can
    // win the race against the completion-state insertion.
    let mut pii_detector_sentence_server = MockServer::new_http("pii_detector_sentence");
    pii_detector_sentence_server.mock(|when, then| {
        when.post()
            .path(TEXT_CONTENTS_DETECTOR_ENDPOINT)
            .header("detector-id", PII_DETECTOR_SENTENCE)
            .json(ContentAnalysisRequest {
                contents: vec!["Hi!".into()],
                detector_params: DetectorParams::default(),
            });
        // Detector responds immediately with no detections
        then.json(json!([[]]));
    });

    let test_server = TestOrchestratorServer::builder()
        .config_path(ORCHESTRATOR_CONFIG_FILE_PATH)
        .openai_server(&openai_server)
        .chunker_servers([&sentence_chunker_server])
        .detector_servers([&pii_detector_sentence_server])
        .build()
        .await?;

    // Issue a streaming chat completion request with the output detector enabled.
    let response = test_server
        .post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT)
        .json(&json!({
            "stream": true,
            "model": "test-0B",
            "detectors": {
                "input": {},
                "output": {
                    "pii_detector_sentence": {}
                }
            },
            "messages": [
                Message { role: Role::User, content: Some(Content::Text("Hello!".into())), ..Default::default()},
            ],
        }))
        .send()
        .await?;
    assert_eq!(response.status(), StatusCode::OK);

    let sse_stream: SseStream<ChatCompletionChunk> = SseStream::new(response.bytes_stream());
    let messages = sse_stream.try_collect::<Vec<_>>().await?;

    // The key assertion: the request completes successfully without panicking,
    // and the stream carries both the role chunk and a terminal finish reason.
    assert!(messages.len() >= 2, "should complete without panicking");
    assert_eq!(messages[0].choices[0].delta.role, Some(Role::Assistant));
    assert!(messages.last().unwrap().choices[0].finish_reason.is_some());

    Ok(())
}