Skip to content

Commit ef2dba9

Browse files
committed
Add class for embedding with Titan
Adds a `TextToEmbedding::Titan` class that calls out to Titan via Bedrock to embed a string, or an array of strings, of text. It works in exactly the same way as the current OpenAI class, and the wrapper `TextToEmbedding` class will call this class if the env var is set to use Titan as the embedding provider.
1 parent dc07489 commit ef2dba9

6 files changed

Lines changed: 124 additions & 6 deletions

File tree

lib/bedrock_models.rb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
# Identifiers of the Bedrock models used by the application, kept in one place
# so callers reference a constant rather than repeating the raw model id.
module BedrockModels
  # Model id for Claude 3.7 Sonnet.
  CLAUDE_3_7_SONNET = "eu.anthropic.claude-3-7-sonnet-20250219-v1:0".freeze
  # Model id for the Titan text embedding model (v2).
  TITAN_EMBED_V2 = "amazon.titan-embed-text-v2:0".freeze
end

lib/search/text_to_embedding.rb

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ def self.call(single_or_collection_of_text, llm_provider:)
44
case llm_provider.to_sym
55
when :openai
66
Search::TextToEmbedding::OpenAI.call(single_or_collection_of_text)
7+
when :titan
8+
Search::TextToEmbedding::Titan.call(single_or_collection_of_text)
79
else
810
raise "Unknown provider: #{llm_provider}"
911
end
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
class Search::TextToEmbedding
  # Converts text into embeddings by invoking the Titan embedding model
  # (BedrockModels::TITAN_EMBED_V2) through the Bedrock runtime client.
  #
  # Accepts either a single String or a collection of Strings. A String input
  # yields one embedding; any other input yields an array of embeddings, one
  # per item and in the same order as the input.
  class Titan
    # Maximum number of characters of each text sent to the model. Anything
    # beyond this is dropped before the request is made.
    # NOTE(review): presumably mirrors the model's documented character limit
    # - confirm against the Titan documentation.
    INPUT_TEXT_LENGTH_LIMIT = 50_000

    # Convenience entry point: builds an instance with the given arguments and
    # runs it.
    def self.call(...) = new(...).call

    # @param single_or_collection_of_text [String, Array<String>] the text, or
    #   collection of texts, to embed
    def initialize(single_or_collection_of_text)
      # Remember whether a bare String was given so that #call can return a
      # single embedding rather than a one-element array.
      @string_input = single_or_collection_of_text.is_a?(String)
      @text_collection = Array(single_or_collection_of_text)
    end

    # Embeds every text in the collection.
    #
    # @return the embedding for a String input, otherwise an array holding one
    #   embedding per input text
    def call
      # For strings longer than the embedding model limit we have to truncate
      # the text.
      # This is done silently and in future we may want to log this in a
      # database or log a warning.
      to_embed = text_collection.map { |text| keep_input_within_text_limit(text) }

      embeddings = convert_text_to_embeddings(to_embed)

      # return just first embedding rather than an array of embeddings if we
      # weren't given an array input
      string_input ? embeddings.first : embeddings
    end

  private

    attr_reader :string_input, :text_collection

    # Memoised so that every request made by this instance shares one client.
    def bedrock_client
      @bedrock_client ||= Aws::BedrockRuntime::Client.new
    end

    # Returns at most the first INPUT_TEXT_LENGTH_LIMIT characters of +text+.
    def keep_input_within_text_limit(text)
      text[0...INPUT_TEXT_LENGTH_LIMIT]
    end

    # Makes one invoke_model request per text and returns the value found
    # under the "embedding" key of each parsed JSON response, in input order.
    def convert_text_to_embeddings(to_embed_collection)
      to_embed_collection.map do |text|
        response = bedrock_client.invoke_model(
          model_id: BedrockModels::TITAN_EMBED_V2,
          body: {
            inputText: text,
          }.to_json,
        )

        JSON.parse(response.body.read)["embedding"]
      end
    end
  end
end
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Specs for the Titan embedding provider. Each example stubs the Bedrock
# client's invoke_model operation, so no real requests are made.
RSpec.describe Search::TextToEmbedding::Titan do
  describe ".call" do
    it "returns a single embedding array for a string input" do
      client = stub_bedrock_invoke_model(
        bedrock_titan_embedding_response([1.0, 2.0, 3.0]),
      )

      embedding = described_class.call("text")

      expect(client.api_requests.size).to eq(1)

      expect(embedding).to eq([1.0, 2.0, 3.0])
    end

    it "returns an array of embedding arrays for an array input" do
      # One stubbed response per input text: each text is embedded with its
      # own request.
      client = stub_bedrock_invoke_model(
        bedrock_titan_embedding_response([1.0, 2.0, 3.0]),
        bedrock_titan_embedding_response([4.0, 5.0, 6.0]),
      )

      embedding = described_class.call(["Embed this", "Embed that"])

      expect(client.api_requests.size).to eq(2)

      expect(embedding).to eq([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    end

    it "truncates input text to the length limit" do
      client = stub_bedrock_invoke_model(
        bedrock_titan_embedding_response([1.0, 2.0, 3.0]),
      )

      # One character over the limit, so exactly one character must be dropped.
      long_text = "a" * (described_class::INPUT_TEXT_LENGTH_LIMIT + 1)
      described_class.call(long_text)

      expect(client.api_requests.size).to eq(1)

      # Inspect the JSON body that was actually sent to the stubbed client.
      request_body = JSON.parse(
        client.api_requests.first.dig(:params, :body),
      )

      expect(request_body["inputText"].length)
        .to eq(described_class::INPUT_TEXT_LENGTH_LIMIT)
    end
  end
end

spec/lib/search/text_to_embedding_spec.rb

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,21 @@
11
RSpec.describe Search::TextToEmbedding do
22
describe ".call" do
33
let(:text) { "The text" }
4-
let(:provider) { :openai }
54

65
it "calls the OpenAI embedding provider" do
76
expect(Search::TextToEmbedding::OpenAI).to receive(:call).with(text)
8-
described_class.call(text, llm_provider: provider)
7+
described_class.call(text, llm_provider: :openai)
98
end
109

11-
context "when an unknown provider is specified" do
12-
let(:provider) { :unknown_provider }
10+
it "calls the Titan embedding provider" do
11+
expect(Search::TextToEmbedding::Titan).to receive(:call).with(text)
12+
described_class.call(text, llm_provider: :titan)
13+
end
1314

15+
context "when an unknown provider is specified" do
1416
it "raises an error" do
15-
expect { described_class.call(text, llm_provider: provider) }
16-
.to raise_error(RuntimeError, "Unknown provider: #{provider}")
17+
expect { described_class.call(text, llm_provider: "notreal") }
18+
.to raise_error(RuntimeError, "Unknown provider: notreal")
1719
end
1820
end
1921
end

spec/support/stub_bedrock.rb

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,13 @@ def stub_bedrock_converse(*responses)
2828
bedrock_client
2929
end
3030

31+
# Makes Aws::BedrockRuntime::Client.new return a client built with
# stub_responses enabled, and registers the given responses for its
# invoke_model operation.
#
# @param responses one or more stubbed invoke_model responses (for example
#   the hashes built by bedrock_titan_embedding_response)
# @return [Aws::BedrockRuntime::Client] the stubbed client, so that specs can
#   inspect the requests it received (e.g. via api_requests)
def stub_bedrock_invoke_model(*responses)
  bedrock_client = Aws::BedrockRuntime::Client.new(stub_responses: true)
  allow(Aws::BedrockRuntime::Client).to receive(:new).and_return(bedrock_client)
  bedrock_client.stub_responses(:invoke_model, responses)
  bedrock_client
end
37+
3138
def bedrock_claude_structured_answer_response(question, answer, answered: true)
3239
lambda do |context|
3340
given_question = context.params.dig(:messages, -1, :content, 0, :text)
@@ -88,6 +95,15 @@ def bedrock_claude_guardrail_response(triggered: false, triggered_guardrails: []
8895
end
8996
end
9097

98+
# Builds a stubbed invoke_model response whose JSON body carries the given
# embedding under the "embedding" key.
#
# @param embedding_array the embedding to place in the response body
# @return [Hash] a hash with :content_type and a JSON-encoded :body
def bedrock_titan_embedding_response(embedding_array)
  {
    content_type: "application/json",
    body: {
      embedding: embedding_array,
    }.to_json,
  }
end
106+
91107
def bedrock_claude_text_response(response_text,
92108
user_message: nil,
93109
input_tokens: 10,

0 commit comments

Comments
 (0)