Skip to content

Commit 9e670ad

Browse files
committed
Merge branch 'main' into release-0.18.3
2 parents 2b70c04 + b37c4bf commit 9e670ad

File tree

2 files changed

+194
-6
lines changed

src/orchestrator/handlers/chat_completions_detection/streaming.rs

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
limitations under the License.
1515
1616
*/
17-
use std::{collections::HashMap, sync::Arc};
17+
use std::{collections::HashMap, sync::Arc, time::Duration};
1818

1919
use futures::{StreamExt, future::try_join_all};
2020
use opentelemetry::trace::TraceId;
@@ -39,6 +39,14 @@ use crate::{
3939
},
4040
};
4141

42+
/// Maximum time to wait for a completion-state entry to appear for a choice index.
///
/// Detectors can finish before the LLM stream has inserted the corresponding
/// completions into shared state. Waiters resolve this race by cooperatively
/// yielding with `tokio::task::yield_now()` — not by busy-waiting — so in the
/// normal case each poll completes within microseconds. This bound only fires
/// when the generation server is exceptionally slow or unhealthy.
const COMPLETION_WAIT_TIMEOUT: Duration = Duration::from_secs(30);
49+
4250
pub async fn handle_streaming(
4351
ctx: Arc<Context>,
4452
task: ChatCompletionsDetectionTask,
@@ -563,14 +571,33 @@ async fn handle_whole_doc_detection(
563571
}
564572

565573
/// Builds a response with output detections.
566-
fn output_detection_response(
574+
async fn output_detection_response(
567575
completion_state: &Arc<CompletionState<ChatCompletionChunk>>,
568576
choice_index: u32,
569577
chunk: Chunk,
570578
detections: Vec<Detection>,
571579
) -> Result<ChatCompletionChunk, Error> {
572-
// Get chat completions for this choice index
573-
let chat_completions = completion_state.completions.get(&choice_index).unwrap();
580+
// Wait for entry to exist (yields to other tasks until ready)
581+
let chat_completions = {
582+
let wait_for_entry = async {
583+
loop {
584+
if let Some(entry) = completion_state.completions.get(&choice_index) {
585+
return entry;
586+
}
587+
tokio::task::yield_now().await;
588+
}
589+
};
590+
591+
match tokio::time::timeout(COMPLETION_WAIT_TIMEOUT, wait_for_entry).await {
592+
Ok(entry) => entry,
593+
Err(_) => {
594+
return Err(Error::Other(format!(
595+
"completion entry for choice_index {} not ready after {:?} timeout",
596+
choice_index, COMPLETION_WAIT_TIMEOUT
597+
)));
598+
}
599+
}
600+
};
574601
// Get range of chat completions for this chunk
575602
let chat_completions = chat_completions
576603
.range(chunk.input_start_index..=chunk.input_end_index)
@@ -709,6 +736,7 @@ async fn process_detection_batch_stream(
709736
Ok((choice_index, chunk, detections)) => {
710737
let input_end_index = chunk.input_end_index;
711738
match output_detection_response(&completion_state, choice_index, chunk, detections)
739+
.await
712740
{
713741
Ok(chat_completion) => {
714742
// Send chat completion to response channel
@@ -718,8 +746,29 @@ async fn process_detection_batch_stream(
718746
return;
719747
}
720748
// If this is the final chat completion chunk with content, send chat completion chunk with finish reason
721-
let chat_completions =
722-
completion_state.completions.get(&choice_index).unwrap();
749+
// Wait for entry to exist (yields to other tasks until ready)
750+
let chat_completions = {
751+
let wait_for_entry = async {
752+
loop {
753+
if let Some(entry) =
754+
completion_state.completions.get(&choice_index)
755+
{
756+
return entry;
757+
}
758+
tokio::task::yield_now().await;
759+
}
760+
};
761+
762+
match tokio::time::timeout(COMPLETION_WAIT_TIMEOUT, wait_for_entry)
763+
.await
764+
{
765+
Ok(entry) => entry,
766+
Err(_) => {
767+
error!(%trace_id, %choice_index, "completion entry not ready after {:?} timeout", COMPLETION_WAIT_TIMEOUT);
768+
return;
769+
}
770+
}
771+
};
723772
if chat_completions.keys().rev().nth(1) == Some(&input_end_index)
724773
&& let Some((_, chat_completion)) = chat_completions.last_key_value()
725774
&& chat_completion

tests/chat_completions_streaming.rs

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5512,3 +5512,142 @@ async fn detector_internal_server_error() -> Result<(), anyhow::Error> {
55125512

55135513
Ok(())
55145514
}
5515+
5516+
#[test(tokio::test)]
5517+
async fn fast_detector_race_condition() -> Result<(), anyhow::Error> {
5518+
// This test verifies the fix for the race condition where detectors respond faster
5519+
// than the LLM stream inserts completions into shared state.
5520+
5521+
let mut openai_server = MockServer::new_http("openai");
5522+
openai_server.mock(|when, then| {
5523+
when.post()
5524+
.path(CHAT_COMPLETIONS_ENDPOINT)
5525+
.json(json!({
5526+
"stream": true,
5527+
"model": "test-0B",
5528+
"messages": [
5529+
Message { role: Role::User, content: Some(Content::Text("Hello!".into())), ..Default::default()},
5530+
]
5531+
})
5532+
);
5533+
then.text_stream(sse([
5534+
ChatCompletionChunk {
5535+
id: "chatcmpl-test".into(),
5536+
object: "chat.completion.chunk".into(),
5537+
created: 1749227854,
5538+
model: "test-0B".into(),
5539+
choices: vec![ChatCompletionChunkChoice {
5540+
index: 0,
5541+
delta: ChatCompletionDelta {
5542+
role: Some(Role::Assistant),
5543+
..Default::default()
5544+
},
5545+
..Default::default()
5546+
}],
5547+
..Default::default()
5548+
},
5549+
ChatCompletionChunk {
5550+
id: "chatcmpl-test".into(),
5551+
object: "chat.completion.chunk".into(),
5552+
created: 1749227854,
5553+
model: "test-0B".into(),
5554+
choices: vec![ChatCompletionChunkChoice {
5555+
index: 0,
5556+
delta: ChatCompletionDelta {
5557+
content: Some("Hi!".into()),
5558+
..Default::default()
5559+
},
5560+
..Default::default()
5561+
}],
5562+
..Default::default()
5563+
},
5564+
ChatCompletionChunk {
5565+
id: "chatcmpl-test".into(),
5566+
object: "chat.completion.chunk".into(),
5567+
created: 1749227854,
5568+
model: "test-0B".into(),
5569+
choices: vec![ChatCompletionChunkChoice {
5570+
index: 0,
5571+
finish_reason: Some("stop".into()),
5572+
..Default::default()
5573+
}],
5574+
..Default::default()
5575+
},
5576+
]));
5577+
});
5578+
5579+
let mut sentence_chunker_server = MockServer::new_grpc("sentence_chunker");
5580+
sentence_chunker_server.mock(|when, then| {
5581+
when.post()
5582+
.path(CHUNKER_STREAMING_ENDPOINT)
5583+
.header(CHUNKER_MODEL_ID_HEADER_NAME, "sentence_chunker")
5584+
.pb_stream(vec![BidiStreamingChunkerTokenizationTaskRequest {
5585+
text_stream: "Hi!".into(),
5586+
input_index_stream: 1,
5587+
}]);
5588+
then.pb_stream(vec![ChunkerTokenizationStreamResult {
5589+
results: vec![Token {
5590+
start: 0,
5591+
end: 3,
5592+
text: "Hi!".into(),
5593+
}],
5594+
token_count: 0,
5595+
processed_index: 3,
5596+
start_index: 0,
5597+
input_start_index: 1,
5598+
input_end_index: 1,
5599+
}]);
5600+
});
5601+
5602+
// Fast detector that responds immediately with no detections
5603+
let mut pii_detector_sentence_server = MockServer::new_http("pii_detector_sentence");
5604+
pii_detector_sentence_server.mock(|when, then| {
5605+
when.post()
5606+
.path(TEXT_CONTENTS_DETECTOR_ENDPOINT)
5607+
.header("detector-id", PII_DETECTOR_SENTENCE)
5608+
.json(ContentAnalysisRequest {
5609+
contents: vec!["Hi!".into()],
5610+
detector_params: DetectorParams::default(),
5611+
});
5612+
// Detector responds immediately with no detections
5613+
then.json(json!([[]]));
5614+
});
5615+
5616+
let test_server = TestOrchestratorServer::builder()
5617+
.config_path(ORCHESTRATOR_CONFIG_FILE_PATH)
5618+
.openai_server(&openai_server)
5619+
.chunker_servers([&sentence_chunker_server])
5620+
.detector_servers([&pii_detector_sentence_server])
5621+
.build()
5622+
.await?;
5623+
5624+
let response = test_server
5625+
.post(ORCHESTRATOR_CHAT_COMPLETIONS_DETECTION_ENDPOINT)
5626+
.json(&json!({
5627+
"stream": true,
5628+
"model": "test-0B",
5629+
"detectors": {
5630+
"input": {},
5631+
"output": {
5632+
"pii_detector_sentence": {}
5633+
}
5634+
},
5635+
"messages": [
5636+
Message { role: Role::User, content: Some(Content::Text("Hello!".into())), ..Default::default()},
5637+
],
5638+
}))
5639+
.send()
5640+
.await?;
5641+
assert_eq!(response.status(), StatusCode::OK);
5642+
5643+
let sse_stream: SseStream<ChatCompletionChunk> = SseStream::new(response.bytes_stream());
5644+
let messages = sse_stream.try_collect::<Vec<_>>().await?;
5645+
5646+
// The key assertion: request completes successfully without panicking
5647+
5648+
assert!(messages.len() >= 2, "should complete without panicking");
5649+
assert_eq!(messages[0].choices[0].delta.role, Some(Role::Assistant));
5650+
assert!(messages.last().unwrap().choices[0].finish_reason.is_some());
5651+
5652+
Ok(())
5653+
}

0 commit comments

Comments
 (0)