Skip to content

Commit 779f6b3

Browse files
Add Bedrock integration tests (0xPlaygrounds#1707)
* Add Bedrock integration tests * tests * tests
1 parent 91e5d97 commit 779f6b3

10 files changed

Lines changed: 525 additions & 0 deletions

File tree

tests/integrations.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
clippy::unreachable
77
)]
88

9+
#[cfg(feature = "bedrock")]
10+
#[path = "integrations/bedrock/mod.rs"]
11+
mod bedrock;
912
#[cfg(feature = "lancedb")]
1013
#[path = "integrations/lancedb/mod.rs"]
1114
mod lancedb;
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
//! Live Bedrock Anthropic adaptive-thinking regression tests.
2+
3+
use futures::StreamExt;
4+
use rig::agent::AgentBuilder;
5+
use rig::client::CompletionClient;
6+
use rig::completion::{CompletionModel as _, Prompt};
7+
use rig::streaming::StreamedAssistantContent;
8+
use serde_json::json;
9+
10+
use super::{
11+
anthropic_adaptive_model, anthropic_signature_only_model, client,
12+
support::{ALPHA_SIGNAL_OUTPUT, AlphaSignal, assert_contains_all_case_insensitive},
13+
};
14+
15+
fn adaptive_thinking_params() -> serde_json::Value {
16+
json!({
17+
"thinking": {
18+
"type": "adaptive"
19+
}
20+
})
21+
}
22+
23+
#[tokio::test]
24+
#[ignore = "requires AWS credentials and Bedrock Anthropic adaptive-thinking model access"]
25+
async fn adaptive_thinking_prompt_caching_tool_roundtrip_regression() {
26+
let model = client()
27+
.completion_model(anthropic_adaptive_model())
28+
.with_prompt_caching();
29+
let agent = AgentBuilder::new(model)
30+
.preamble(
31+
"You must call tools when the user asks for their result. \
32+
After a tool result is available, answer with the exact result.",
33+
)
34+
.max_tokens(2048)
35+
.additional_params(adaptive_thinking_params())
36+
.tool(AlphaSignal)
37+
.build();
38+
39+
let response = agent
40+
.prompt("Call `lookup_harbor_label` exactly once, then answer with the exact tool output.")
41+
.await
42+
.expect("adaptive-thinking prompt-caching tool roundtrip should succeed");
43+
44+
assert_contains_all_case_insensitive(&response, &[ALPHA_SIGNAL_OUTPUT]);
45+
}
46+
47+
#[tokio::test]
48+
#[ignore = "requires AWS credentials and Bedrock Anthropic adaptive-thinking model access"]
49+
async fn streaming_emits_signature_only_adaptive_reasoning_regression() {
50+
let model = client().completion_model(anthropic_signature_only_model());
51+
let request = model
52+
.completion_request("What is 2 + 2? Answer with only the number.")
53+
.max_tokens(2048)
54+
.additional_params(adaptive_thinking_params())
55+
.build();
56+
let mut stream = model
57+
.stream(request)
58+
.await
59+
.expect("adaptive-thinking Bedrock stream should start");
60+
61+
let mut reasoning_chunks = 0;
62+
let mut signature_chunks = 0;
63+
let mut signature_only_chunks = 0;
64+
let mut got_final = false;
65+
66+
while let Some(item) = stream.next().await {
67+
match item.expect("adaptive-thinking Bedrock stream item should succeed") {
68+
StreamedAssistantContent::Reasoning(reasoning) => {
69+
reasoning_chunks += 1;
70+
if reasoning.first_signature().is_some() {
71+
signature_chunks += 1;
72+
if reasoning.display_text().is_empty() {
73+
signature_only_chunks += 1;
74+
}
75+
}
76+
}
77+
StreamedAssistantContent::Final(_) => got_final = true,
78+
_ => {}
79+
}
80+
}
81+
82+
assert!(got_final, "stream should emit a final response");
83+
assert!(
84+
reasoning_chunks > 0,
85+
"expected at least one adaptive-thinking reasoning chunk"
86+
);
87+
assert!(
88+
signature_chunks > 0,
89+
"expected adaptive-thinking reasoning to include a Bedrock signature"
90+
);
91+
assert!(
92+
signature_only_chunks > 0,
93+
"expected at least one signature-only reasoning chunk"
94+
);
95+
}
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
//! AWS Bedrock completion smoke tests inspired by the OpenAI and Anthropic provider tests.
2+
3+
use rig::agent::AgentBuilder;
4+
use rig::client::CompletionClient;
5+
use rig::completion::Prompt;
6+
7+
use super::{
8+
BEDROCK_COMPLETION_MODEL, client,
9+
support::{
10+
Adder, BASIC_PREAMBLE, BASIC_PROMPT, CONTEXT_DOCS, CONTEXT_PROMPT,
11+
STREAMING_TOOLS_PREAMBLE, STREAMING_TOOLS_PROMPT, Subtract,
12+
assert_contains_any_case_insensitive, assert_mentions_expected_number,
13+
assert_nonempty_response,
14+
},
15+
};
16+
17+
#[tokio::test]
18+
#[ignore = "requires AWS credentials and Bedrock model access"]
19+
async fn completion_smoke() {
20+
let agent = client()
21+
.agent(BEDROCK_COMPLETION_MODEL)
22+
.preamble(BASIC_PREAMBLE)
23+
.build();
24+
25+
let response = agent
26+
.prompt(BASIC_PROMPT)
27+
.await
28+
.expect("completion should succeed");
29+
30+
assert_nonempty_response(&response);
31+
}
32+
33+
#[tokio::test]
34+
#[ignore = "requires AWS credentials and Bedrock model access"]
35+
async fn completion_with_context_smoke() {
36+
let agent = client()
37+
.agent(BEDROCK_COMPLETION_MODEL)
38+
.preamble("Answer the user using only the supplied context.")
39+
.context(CONTEXT_DOCS[0])
40+
.context(CONTEXT_DOCS[1])
41+
.context(CONTEXT_DOCS[2])
42+
.build();
43+
44+
let response = agent
45+
.prompt(CONTEXT_PROMPT)
46+
.await
47+
.expect("context completion should succeed");
48+
49+
assert_contains_any_case_insensitive(&response, &["ancient tool", "farm"]);
50+
}
51+
52+
#[tokio::test]
53+
#[ignore = "requires AWS credentials and Bedrock model access"]
54+
async fn tool_roundtrip_smoke() {
55+
let agent = client()
56+
.agent(BEDROCK_COMPLETION_MODEL)
57+
.preamble(STREAMING_TOOLS_PREAMBLE)
58+
.max_tokens(1024)
59+
.tool(Adder)
60+
.tool(Subtract)
61+
.build();
62+
63+
let response = agent
64+
.prompt(STREAMING_TOOLS_PROMPT)
65+
.await
66+
.expect("tool prompt should succeed");
67+
68+
assert_mentions_expected_number(&response, -3);
69+
}
70+
71+
#[tokio::test]
72+
#[ignore = "requires AWS credentials and Bedrock model access"]
73+
async fn prompt_caching_completion_smoke() {
74+
let bedrock_client = client();
75+
let model = bedrock_client
76+
.completion_model(BEDROCK_COMPLETION_MODEL)
77+
.with_prompt_caching();
78+
let agent = AgentBuilder::new(model).preamble(BASIC_PREAMBLE).build();
79+
80+
let response = agent
81+
.prompt(BASIC_PROMPT)
82+
.await
83+
.expect("prompt-caching completion should succeed");
84+
85+
assert_nonempty_response(&response);
86+
}
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
//! AWS Bedrock document prompt smoke tests inspired by Anthropic document tests.
2+
3+
use rig::OneOrMany;
4+
use rig::client::CompletionClient;
5+
use rig::completion::Prompt;
6+
use rig::message::{Document, DocumentMediaType, DocumentSourceKind, Message, UserContent};
7+
8+
use super::{
9+
BEDROCK_COMPLETION_MODEL, client,
10+
support::{assert_contains_any_case_insensitive, assert_nonempty_response},
11+
};
12+
13+
fn rust_document() -> String {
14+
r#"
15+
The Rust Programming Language
16+
17+
Rust is a systems programming language focused on three goals: safety, speed,
18+
and concurrency. It accomplishes these goals without a garbage collector.
19+
20+
Key Features:
21+
- Zero-cost abstractions
22+
- Move semantics
23+
- Guaranteed memory safety
24+
- Threads without data races
25+
"#
26+
.trim()
27+
.to_string()
28+
}
29+
30+
#[tokio::test]
31+
#[ignore = "requires AWS credentials and Bedrock model access"]
32+
async fn plaintext_document_prompt() {
33+
let agent = client()
34+
.agent(BEDROCK_COMPLETION_MODEL)
35+
.preamble("Summarize the provided document.")
36+
.temperature(0.5)
37+
.build();
38+
39+
let document = Document {
40+
data: DocumentSourceKind::String(rust_document()),
41+
media_type: Some(DocumentMediaType::TXT),
42+
additional_params: None,
43+
};
44+
let response = agent
45+
.prompt(document)
46+
.await
47+
.expect("document prompt should succeed");
48+
49+
assert_nonempty_response(&response);
50+
assert_contains_any_case_insensitive(&response, &["safety", "speed", "concurrency"]);
51+
}
52+
53+
#[tokio::test]
54+
#[ignore = "requires AWS credentials and Bedrock model access"]
55+
async fn plaintext_document_with_instruction() {
56+
let agent = client()
57+
.agent(BEDROCK_COMPLETION_MODEL)
58+
.preamble("Answer from the provided document.")
59+
.temperature(0.5)
60+
.build();
61+
62+
let response = agent
63+
.prompt(Message::User {
64+
content: OneOrMany::many(vec![
65+
UserContent::document(rust_document(), Some(DocumentMediaType::TXT)),
66+
UserContent::text("List the three main goals of Rust mentioned in this document."),
67+
])
68+
.expect("content should be non-empty"),
69+
})
70+
.await
71+
.expect("instruction prompt should succeed");
72+
73+
assert_contains_any_case_insensitive(&response, &["safety", "speed", "concurrency"]);
74+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//! AWS Bedrock embedding smoke test inspired by provider embedding coverage.
2+
3+
use rig::client::EmbeddingsClient;
4+
use rig::embeddings::EmbeddingModel as _;
5+
6+
use super::{
7+
BEDROCK_EMBEDDING_MODEL, client,
8+
support::{EMBEDDING_INPUTS, assert_embeddings_nonempty_and_consistent},
9+
};
10+
11+
#[tokio::test]
12+
#[ignore = "requires AWS credentials and Bedrock embedding model access"]
13+
async fn embeddings_smoke() {
14+
let model = client().embedding_model_with_ndims(BEDROCK_EMBEDDING_MODEL, 256);
15+
let embeddings = model
16+
.embed_texts(EMBEDDING_INPUTS.into_iter().map(str::to_string))
17+
.await
18+
.expect("embedding request should succeed");
19+
20+
assert_embeddings_nonempty_and_consistent(&embeddings, EMBEDDING_INPUTS.len());
21+
assert!(
22+
embeddings
23+
.iter()
24+
.all(|embedding| embedding.vec.len() == 256),
25+
"Titan text embeddings v2 should return the requested 256 dimensions"
26+
);
27+
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
//! AWS Bedrock extractor smoke tests inspired by the provider extractor tests.
2+
3+
use rig::client::CompletionClient;
4+
use rig::message::Message;
5+
6+
use super::{
7+
BEDROCK_COMPLETION_MODEL, client,
8+
support::{EXTRACTOR_TEXT, SmokePerson, assert_nonempty_response},
9+
};
10+
11+
fn assert_smoke_person(person: &SmokePerson) {
12+
let first_name = person
13+
.first_name
14+
.as_deref()
15+
.expect("first_name should be present");
16+
let last_name = person
17+
.last_name
18+
.as_deref()
19+
.expect("last_name should be present");
20+
let job = person.job.as_deref().expect("job should be present");
21+
22+
assert_nonempty_response(first_name);
23+
assert_nonempty_response(last_name);
24+
assert_nonempty_response(job);
25+
}
26+
27+
#[tokio::test]
28+
#[ignore = "requires AWS credentials and Bedrock model access"]
29+
async fn extractor_smoke() {
30+
let extractor = client()
31+
.extractor::<SmokePerson>(BEDROCK_COMPLETION_MODEL)
32+
.build();
33+
34+
let response = extractor
35+
.extract_with_usage(EXTRACTOR_TEXT)
36+
.await
37+
.expect("extractor request should succeed");
38+
39+
assert_smoke_person(&response.data);
40+
assert!(response.usage.total_tokens > 0, "usage should be populated");
41+
}
42+
43+
#[tokio::test]
44+
#[ignore = "requires AWS credentials and Bedrock model access"]
45+
async fn extractor_with_chat_history_smoke() {
46+
let extractor = client()
47+
.extractor::<SmokePerson>(BEDROCK_COMPLETION_MODEL)
48+
.build();
49+
50+
let response = extractor
51+
.extract_with_chat_history_with_usage(
52+
"The text is about Ada Lovelace, a mathematician.",
53+
vec![Message::user(
54+
"Extract the person's name and job from the next message.",
55+
)],
56+
)
57+
.await
58+
.expect("extractor request with chat history should succeed");
59+
60+
assert_smoke_person(&response.data);
61+
assert!(response.usage.total_tokens > 0, "usage should be populated");
62+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
//! AWS Bedrock image generation smoke test inspired by OpenAI image generation tests.
2+
3+
use rig::image_generation::ImageGenerationModel;
4+
use rig::prelude::ImageGenerationClient;
5+
6+
use super::{
7+
BEDROCK_IMAGE_MODEL, client,
8+
support::{IMAGE_PROMPT, assert_nonempty_bytes},
9+
};
10+
11+
#[tokio::test]
12+
#[ignore = "requires AWS credentials and Bedrock image generation model access"]
13+
async fn image_generation_smoke() {
14+
let model = client().image_generation_model(BEDROCK_IMAGE_MODEL);
15+
let response = model
16+
.image_generation_request()
17+
.prompt(IMAGE_PROMPT)
18+
.width(512)
19+
.height(512)
20+
.send()
21+
.await
22+
.expect("image generation request should succeed");
23+
24+
assert_nonempty_bytes(&response.image);
25+
}

0 commit comments

Comments
 (0)