Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

121 changes: 63 additions & 58 deletions crates/rig-bedrock/src/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
use crate::{
client::Client,
types::{
assistant_content::AwsConverseOutput, completion_request::AwsCompletionRequest,
converse_output::InternalConverseOutput, errors::AwsSdkConverseError,
assistant_content::{AwsConverseOutput, completion_response_events},
completion_request::AwsCompletionRequest,
converse_output::InternalConverseOutput,
errors::AwsSdkConverseError,
},
};

use rig_core::completion::{self, CompletionError, CompletionRequest};
use rig_core::streaming::StreamingCompletionResponse;
use rig_core::model_event::ModelEventStream;
use rig_core::telemetry::SpanCombinator;
use tracing::Instrument;

Expand Down Expand Up @@ -219,77 +221,80 @@ impl completion::CompletionModel for CompletionModel {
Self::new(client.clone(), model)
}

async fn completion(
async fn events(
&self,
completion_request: completion::CompletionRequest,
) -> Result<completion::CompletionResponse<AwsConverseOutput>, CompletionError> {
let request_model = resolve_request_model(&self.model, &completion_request);
) -> Result<rig_core::model_event::ModelEventStream<Self::Response>, CompletionError> {
async {
let request_model = resolve_request_model(&self.model, &completion_request);

let span = if tracing::Span::current().is_disabled() {
tracing::info_span!(
target: "rig_core::completions",
"chat",
gen_ai.operation.name = "chat",
gen_ai.provider.name = "aws_bedrock",
gen_ai.request.model = &request_model,
gen_ai.system_instructions = &completion_request.preamble,
gen_ai.response.id = tracing::field::Empty,
gen_ai.response.model = tracing::field::Empty,
gen_ai.usage.output_tokens = tracing::field::Empty,
gen_ai.usage.input_tokens = tracing::field::Empty,
gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
)
} else {
tracing::Span::current()
};
let span = if tracing::Span::current().is_disabled() {
tracing::info_span!(
target: "rig_core::completions",
"chat",
gen_ai.operation.name = "chat",
gen_ai.provider.name = "aws_bedrock",
gen_ai.request.model = &request_model,
gen_ai.system_instructions = &completion_request.preamble,
gen_ai.response.id = tracing::field::Empty,
gen_ai.response.model = tracing::field::Empty,
gen_ai.usage.output_tokens = tracing::field::Empty,
gen_ai.usage.input_tokens = tracing::field::Empty,
gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
)
} else {
tracing::Span::current()
};

let request = AwsCompletionRequest {
inner: completion_request,
prompt_caching: self.prompt_caching,
};
let request = AwsCompletionRequest {
inner: completion_request,
prompt_caching: self.prompt_caching,
};

let mut converse_builder = self
.client
.get_inner()
.await
.converse()
.model_id(request_model.clone());
let mut converse_builder = self
.client
.get_inner()
.await
.converse()
.model_id(request_model.clone());

let tool_config = request.tools_config()?;
let messages = request.messages()?;
converse_builder = converse_builder
.set_additional_model_request_fields(request.additional_params())
.set_inference_config(request.inference_config())
.set_tool_config(tool_config)
.set_system(request.system_prompt()?)
.set_messages(Some(messages));
let tool_config = request.tools_config()?;
let messages = request.messages()?;
converse_builder = converse_builder
.set_additional_model_request_fields(request.additional_params())
.set_inference_config(request.inference_config())
.set_tool_config(tool_config)
.set_system(request.system_prompt()?)
.set_messages(Some(messages));

async move {
let response = converse_builder.send().await.map_err(|sdk_error| {
Into::<CompletionError>::into(AwsSdkConverseError(sdk_error))
})?;
async move {
let response = converse_builder.send().await.map_err(|sdk_error| {
Into::<CompletionError>::into(AwsSdkConverseError(sdk_error))
})?;

let response: InternalConverseOutput = response.try_into().map_err(|x| {
CompletionError::ProviderError(format!("Type conversion error: {x}"))
})?;
let response: InternalConverseOutput = response.try_into().map_err(|x| {
CompletionError::ProviderError(format!("Type conversion error: {x}"))
})?;

let aws_output = AwsConverseOutput(response);
let aws_output = AwsConverseOutput(response);

let span = tracing::Span::current();
span.record_response_metadata(&aws_output);
span.record_token_usage(&aws_output);
let span = tracing::Span::current();
span.record_response_metadata(&aws_output);
span.record_token_usage(&aws_output);

aws_output.try_into()
completion_response_events(aws_output)
}
.instrument(span)
.await
}
.instrument(span)
.await
}

async fn stream(
async fn stream_events(
&self,
request: CompletionRequest,
) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
CompletionModel::stream(self, request).await
) -> Result<ModelEventStream<Self::StreamingResponse>, CompletionError> {
CompletionModel::stream_events(self, request).await
}
}
99 changes: 54 additions & 45 deletions crates/rig-bedrock/src/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@ use crate::{
};
use async_stream::stream;
use aws_sdk_bedrockruntime::types as aws_bedrock;
use rig_core::completion::GetTokenUsage;
use rig_core::streaming::StreamingCompletionResponse;
use rig_core::message::{Reasoning, ToolCall, ToolFunction};
use rig_core::{
completion::CompletionError,
message::ReasoningContent,
streaming::{RawStreamingChoice, RawStreamingToolCall, ToolCallDeltaContent},
completion::{CompletionError, GetTokenUsage},
model_event::{ModelEvent, ModelEventStream, ToolCallDeltaContent},
};
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -65,26 +63,19 @@ struct ReasoningState {
/// Field required` when the conversation is replayed to Bedrock. We must emit
/// whenever either the content or the signature is present; both-empty is
/// still skipped.
fn finalize_reasoning(
state: ReasoningState,
) -> Option<RawStreamingChoice<BedrockStreamingResponse>> {
fn finalize_reasoning(state: ReasoningState) -> Option<ModelEvent<BedrockStreamingResponse>> {
if state.content.is_empty() && state.signature.is_none() {
return None;
}
Some(RawStreamingChoice::Reasoning {
id: None,
content: ReasoningContent::Text {
text: state.content,
signature: state.signature,
},
})
let reasoning = Reasoning::new_with_signature(&state.content, state.signature);
Some(ModelEvent::ReasoningDone { reasoning })
}

impl CompletionModel {
pub(crate) async fn stream(
pub(crate) async fn stream_events(
&self,
completion_request: rig_core::completion::CompletionRequest,
) -> Result<StreamingCompletionResponse<BedrockStreamingResponse>, CompletionError> {
) -> Result<ModelEventStream<BedrockStreamingResponse>, CompletionError> {
let request_model = resolve_request_model(&self.model, &completion_request);
let request = AwsCompletionRequest {
inner: completion_request,
Expand Down Expand Up @@ -122,7 +113,7 @@ impl CompletionModel {
match delta {
aws_bedrock::ContentBlockDelta::Text(text) => {
if current_tool_call.is_none() {
yield Ok(RawStreamingChoice::Message(text))
yield Ok(ModelEvent::TextDelta { text })
}
},
aws_bedrock::ContentBlockDelta::ToolUse(tool) => {
Expand All @@ -131,7 +122,7 @@ impl CompletionModel {
tool_call.input_json.push_str(&delta);

// Emit the delta so the UI can show progress
yield Ok(RawStreamingChoice::ToolCallDelta {
yield Ok(ModelEvent::ToolCallDelta {
id: tool_call.id.clone(),
internal_call_id: tool_call.internal_call_id.clone(),
content: ToolCallDeltaContent::Delta(delta),
Expand All @@ -150,8 +141,8 @@ impl CompletionModel {
}

if !text.is_empty() {
yield Ok(RawStreamingChoice::ReasoningDelta {
reasoning: text.clone(),
yield Ok(ModelEvent::ReasoningDelta {
text: text.clone(),
id: None,
})
}
Expand Down Expand Up @@ -181,7 +172,7 @@ impl CompletionModel {
internal_call_id: internal_call_id.clone(),
input_json: String::new(),
});
yield Ok(RawStreamingChoice::ToolCallDelta {
yield Ok(ModelEvent::ToolCallDelta {
id: tool_use.tool_use_id,
internal_call_id,
content: ToolCallDeltaContent::Name(tool_use.name),
Expand All @@ -206,10 +197,19 @@ impl CompletionModel {
} else {
serde_json::from_str(tool_call.input_json.as_str())?
};
yield Ok(RawStreamingChoice::ToolCall(
RawStreamingToolCall::new(tool_call.id, tool_call.name, tool_input)
.with_internal_call_id(tool_call.internal_call_id)
));
yield Ok(ModelEvent::ToolCallDone {
tool_call: ToolCall {
id: tool_call.id,
call_id: None,
function: ToolFunction {
name: tool_call.name,
arguments: tool_input,
},
signature: None,
additional_params: None,
},
internal_call_id: Some(tool_call.internal_call_id),
});
} else {
yield Err(CompletionError::ProviderError("Failed to call tool".into()))
}
Expand All @@ -223,29 +223,35 @@ impl CompletionModel {
aws_bedrock::ConverseStreamOutput::Metadata(metadata_event) => {
// Extract usage information from metadata
if let Some(usage) = metadata_event.usage {
yield Ok(RawStreamingChoice::FinalResponse(BedrockStreamingResponse {
let response = BedrockStreamingResponse {
usage: Some(BedrockUsage {
input_tokens: usage.input_tokens,
output_tokens: usage.output_tokens,
total_tokens: usage.total_tokens,
cache_read_input_tokens: usage.cache_read_input_tokens,
cache_write_input_tokens: usage.cache_write_input_tokens,
}),
}));
};
if let Some(usage) = response.token_usage() {
yield Ok(ModelEvent::Usage { usage });
}
yield Ok(ModelEvent::RawResponse { response });
yield Ok(ModelEvent::Done);
}
},
_ => {}
}
}
});

Ok(StreamingCompletionResponse::stream(stream))
Ok(rig_core::completion::codec::result_stream(stream))
}
}

#[cfg(test)]
mod tests {
use super::*;
use rig_core::message::ReasoningContent;

#[test]
fn test_bedrock_usage_creation() {
Expand Down Expand Up @@ -556,17 +562,20 @@ mod tests {

let choice = finalize_reasoning(state).expect("should emit reasoning");
match choice {
RawStreamingChoice::Reasoning { id, content } => {
assert!(id.is_none());
match content {
ReasoningContent::Text { text, signature } => {
assert_eq!(text, "I am thinking");
assert_eq!(signature.as_deref(), Some("sig-abc"));
}
other => panic!("expected ReasoningContent::Text, got {:?}", other),
}
ModelEvent::ReasoningDone { reasoning } => {
assert!(reasoning.id.is_none());
match reasoning.content.as_slice() {
[content] => match content {
ReasoningContent::Text { text, signature } => {
assert_eq!(text, "I am thinking");
assert_eq!(signature.as_deref(), Some("sig-abc"));
}
other => panic!("expected ReasoningContent::Text, got {:?}", other),
},
other => panic!("expected one reasoning content block, got {:?}", other),
};
}
_ => panic!("expected RawStreamingChoice::Reasoning"),
_ => panic!("expected ModelEvent::ReasoningDone"),
}
}

Expand All @@ -583,14 +592,14 @@ mod tests {
let choice =
finalize_reasoning(state).expect("should emit reasoning for signature-only state");
match choice {
RawStreamingChoice::Reasoning { content, .. } => match content {
ReasoningContent::Text { text, signature } => {
ModelEvent::ReasoningDone { reasoning } => match reasoning.content.as_slice() {
[ReasoningContent::Text { text, signature }] => {
assert!(text.is_empty());
assert_eq!(signature.as_deref(), Some("sig-only"));
}
other => panic!("expected ReasoningContent::Text, got {:?}", other),
},
_ => panic!("expected RawStreamingChoice::Reasoning"),
_ => panic!("expected ModelEvent::ReasoningDone"),
}
}

Expand All @@ -604,14 +613,14 @@ mod tests {
let choice =
finalize_reasoning(state).expect("should emit reasoning for content-only state");
match choice {
RawStreamingChoice::Reasoning { content, .. } => match content {
ReasoningContent::Text { text, signature } => {
ModelEvent::ReasoningDone { reasoning } => match reasoning.content.as_slice() {
[ReasoningContent::Text { text, signature }] => {
assert_eq!(text, "thoughts without sig");
assert!(signature.is_none());
}
other => panic!("expected ReasoningContent::Text, got {:?}", other),
},
_ => panic!("expected RawStreamingChoice::Reasoning"),
_ => panic!("expected ModelEvent::ReasoningDone"),
}
}

Expand Down
Loading
Loading