Skip to content

Commit faeeeaf

Browse files
authored
Merge pull request #207 from alphagov/embedding-model-provider
Namespace OpenAI embedding class
2 parents 2e1e8d9 + b6a09dd commit faeeeaf

13 files changed

Lines changed: 178 additions & 111 deletions

File tree

config/application.rb

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -106,5 +106,6 @@ class Application < Rails::Application
106106
config.bigquery_dataset_id = ENV["BIGQUERY_DATASET"]
107107

108108
config.answer_strategy = ENV.fetch("ANSWER_STRATEGY", "openai_structured_answer")
109+
config.embedding_provider = ENV.fetch("EMBEDDING_PROVIDER", "openai")
109110
end
110111
end

lib/message_queue/content_synchroniser/index_content_item.rb

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,7 @@ def id_digests
3838
end
3939

4040
def index_chunks(indexable_chunks)
41-
embeddings = Search::TextToEmbedding.call(indexable_chunks.map(&:plain_content))
41+
embeddings = Search::TextToEmbedding.call(indexable_chunks.map(&:plain_content), llm_provider: :openai)
4242

4343
created = 0
4444
updated = 0

lib/search/chunked_content_repository.rb

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -141,14 +141,21 @@ def id_digest_hash(base_path, batch_size: 100)
141141
items
142142
end
143143

144-
def search_by_embedding(embedding, max_chunks:)
144+
def search_by_embedding(embedding, max_chunks:, llm_provider:)
145+
field_name = case llm_provider.to_sym
146+
when :openai
147+
:openai_embedding
148+
else
149+
raise "Unknown provider: #{llm_provider}"
150+
end
151+
145152
response = client.search(
146153
index:,
147154
body: {
148155
size: max_chunks,
149156
query: {
150157
knn: {
151-
openai_embedding: {
158+
"#{field_name}": {
152159
vector: embedding,
153160
k: max_chunks,
154161
},

lib/search/results_for_question.rb

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -7,14 +7,21 @@ def self.call(question_message)
77
max_results = Rails.configuration.search.thresholds.max_results
88
max_chunks = Rails.configuration.search.thresholds.retrieved_from_index
99

10+
provider = Rails.configuration.embedding_provider
11+
1012
metrics = {}
1113
embedding_start_time = Clock.monotonic_time
12-
embedding = Search::TextToEmbedding.call(question_message)
14+
embedding = Search::TextToEmbedding.call(question_message, llm_provider: provider)
1315
metrics[:embedding_duration] = Clock.monotonic_time - embedding_start_time
1416

1517
search_start_time = Clock.monotonic_time
16-
results = ChunkedContentRepository.new.search_by_embedding(embedding, max_chunks:)
18+
results = ChunkedContentRepository.new.search_by_embedding(
19+
embedding,
20+
max_chunks:,
21+
llm_provider: provider,
22+
)
1723
metrics[:search_duration] = Clock.monotonic_time - search_start_time
24+
metrics[:embedding_provider] = provider
1825

1926
reranking_start_time = Clock.monotonic_time
2027
weighted_results = Search::ResultsForQuestion::Reranker.call(results)

lib/search/text_to_embedding.rb

Lines changed: 6 additions & 50 deletions
Original file line number | Diff line number | Diff line change
@@ -1,56 +1,12 @@
11
module Search
22
class TextToEmbedding
3-
EMBEDDING_MODEL = "text-embedding-3-large".freeze
4-
INPUT_TOKEN_LIMIT = 8191
5-
BATCH_SIZE = 50
6-
7-
def self.call(...) = new(...).call
8-
9-
def initialize(single_or_collection_of_text)
10-
@string_input = single_or_collection_of_text.is_a?(String)
11-
@text_collection = Array(single_or_collection_of_text)
12-
end
13-
14-
def call
15-
to_embed = text_collection.map(&method(:keep_input_within_token_limit))
16-
17-
embeddings = convert_text_to_embeddings(to_embed)
18-
19-
# return just first embedding rather than an array of embeddings if we
20-
# weren't given an array input
21-
string_input ? embeddings.first : embeddings
22-
end
23-
24-
private
25-
26-
attr_reader :string_input, :text_collection
27-
28-
def openai_client
29-
@openai_client ||= OpenAIClient.build
30-
end
31-
32-
def keep_input_within_token_limit(text)
33-
as_tokens = token_encoder.encode(text)
34-
35-
return text if as_tokens.length <= INPUT_TOKEN_LIMIT
36-
37-
token_encoder.decode(as_tokens[...INPUT_TOKEN_LIMIT])
38-
end
39-
40-
def convert_text_to_embeddings(to_embed_collection)
41-
batches = to_embed_collection.each_slice(BATCH_SIZE).to_a
42-
43-
batches.flat_map do |batch|
44-
response = openai_client.embeddings(
45-
parameters: { model: EMBEDDING_MODEL, input: batch },
46-
)
47-
48-
response["data"].map { |data| data["embedding"] }
3+
def self.call(single_or_collection_of_text, llm_provider:)
4+
case llm_provider.to_sym
5+
when :openai
6+
Search::TextToEmbedding::OpenAI.call(single_or_collection_of_text)
7+
else
8+
raise "Unknown provider: #{llm_provider}"
499
end
5010
end
51-
52-
def token_encoder
53-
@token_encoder ||= Tiktoken.encoding_for_model(EMBEDDING_MODEL)
54-
end
5511
end
5612
end
Lines changed: 56 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,56 @@
1+
class Search::TextToEmbedding
2+
class OpenAI
3+
EMBEDDING_MODEL = "text-embedding-3-large".freeze
4+
INPUT_TOKEN_LIMIT = 8191
5+
BATCH_SIZE = 50
6+
7+
def self.call(...) = new(...).call
8+
9+
def initialize(single_or_collection_of_text)
10+
@string_input = single_or_collection_of_text.is_a?(String)
11+
@text_collection = Array(single_or_collection_of_text)
12+
end
13+
14+
def call
15+
to_embed = text_collection.map(&method(:keep_input_within_token_limit))
16+
17+
embeddings = convert_text_to_embeddings(to_embed)
18+
19+
# return just first embedding rather than an array of embeddings if we
20+
# weren't given an array input
21+
string_input ? embeddings.first : embeddings
22+
end
23+
24+
private
25+
26+
attr_reader :string_input, :text_collection
27+
28+
def openai_client
29+
@openai_client ||= OpenAIClient.build
30+
end
31+
32+
def keep_input_within_token_limit(text)
33+
as_tokens = token_encoder.encode(text)
34+
35+
return text if as_tokens.length <= INPUT_TOKEN_LIMIT
36+
37+
token_encoder.decode(as_tokens[...INPUT_TOKEN_LIMIT])
38+
end
39+
40+
def convert_text_to_embeddings(to_embed_collection)
41+
batches = to_embed_collection.each_slice(BATCH_SIZE).to_a
42+
43+
batches.flat_map do |batch|
44+
response = openai_client.embeddings(
45+
parameters: { model: EMBEDDING_MODEL, input: batch },
46+
)
47+
48+
response["data"].map { |data| data["embedding"] }
49+
end
50+
end
51+
52+
def token_encoder
53+
@token_encoder ||= Tiktoken.encoding_for_model(EMBEDDING_MODEL)
54+
end
55+
end
56+
end

lib/tasks/search.rake

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -60,7 +60,7 @@ namespace :search do
6060
end
6161
end
6262

63-
embeddings = Search::TextToEmbedding.call(chunks.map(&:plain_content))
63+
embeddings = Search::TextToEmbedding.call(chunks.map(&:plain_content), llm_provider: :openai)
6464
repository = Search::ChunkedContentRepository.new
6565
indexed = 0
6666

spec/lib/message_queue/content_synchroniser/index_content_item_spec.rb

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,13 +25,13 @@
2525
end
2626

2727
it "applies OpenAI embedding to the data going into the search index" do
28-
allow(Search::TextToEmbedding).to receive(:call).and_call_original
28+
allow(Search::TextToEmbedding::OpenAI).to receive(:call).and_call_original
2929

3030
expect { described_class.call(content_item, repository) }
3131
.to change { repository.count(exists: { field: :openai_embedding }) }
3232
.by(chunks.length)
3333

34-
expect(Search::TextToEmbedding).to have_received(:call)
34+
expect(Search::TextToEmbedding::OpenAI).to have_received(:call)
3535
end
3636

3737
it "returns a Result object" do

spec/lib/search/chunked_content_repository_spec.rb

Lines changed: 21 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -220,7 +220,11 @@
220220
end
221221

222222
it "returns an array of Result objects" do
223-
result = repository.search_by_embedding(openai_embedding, max_chunks: 10)
223+
result = repository.search_by_embedding(
224+
openai_embedding,
225+
max_chunks: 10,
226+
llm_provider: :openai,
227+
)
224228
expected_attributes = chunked_content_records.first
225229
.except(:openai_embedding)
226230
.merge(score: a_value_between(0.9, 1))
@@ -229,12 +233,26 @@
229233
expect(result.first).to have_attributes(**expected_attributes)
230234
end
231235

232-
context "when there are more then the maxiumum chunks" do
236+
it "raises an error if the llm provider is not recognised" do
237+
expect {
238+
repository.search_by_embedding(
239+
openai_embedding,
240+
max_chunks: 10,
241+
llm_provider: :unknown,
242+
)
243+
}.to raise_error("Unknown provider: unknown")
244+
end
245+
246+
context "when there are more than the maxiumum chunks" do
233247
let(:max_chunks) { 10 }
234248
let(:chunked_content_records) { build_list(:chunked_content_record, 11, openai_embedding:) }
235249

236250
it "only returns the first max_chunks" do
237-
result = repository.search_by_embedding(openai_embedding, max_chunks:)
251+
result = repository.search_by_embedding(
252+
openai_embedding,
253+
max_chunks:,
254+
llm_provider: :openai,
255+
)
238256
expect(result.count).to eq max_chunks
239257
end
240258
end

spec/lib/search/results_for_question_spec.rb

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
before do
99
allow(Search::TextToEmbedding)
1010
.to receive(:call)
11-
.with(question_message)
11+
.with(question_message, llm_provider: "openai")
1212
.and_return(openai_embedding)
1313

1414
allow(Rails.configuration.search.thresholds).to receive_messages(minimum_score: min_score, max_results:)
@@ -25,9 +25,11 @@
2525
end
2626

2727
it "retrieves an embedding for the question_message and searches the chunked content repository" do
28+
allow(Rails.configuration).to receive(:embedding_provider).and_return("openai")
29+
2830
result = described_class.call(question_message)
2931
expect(result).to be_a(Search::ResultsForQuestion::ResultSet)
30-
expect(Search::TextToEmbedding).to have_received(:call).with(question_message)
32+
expect(Search::TextToEmbedding).to have_received(:call).with(question_message, llm_provider: "openai")
3133
end
3234

3335
it "has the results over the configured threshold after reranking" do
@@ -45,9 +47,9 @@
4547
["not found 2", a_value_between(0.2, 0.3)])
4648
end
4749

48-
it "populates the metrics attribute with the durations of the embedding, search, and reranking steps" do
50+
it "populates the metrics attribute" do
4951
result = described_class.call(question_message)
50-
expect(result.metrics).to eq({ embedding_duration: 1.5, search_duration: 2.0, reranking_duration: 1.0 })
52+
expect(result.metrics).to eq({ embedding_duration: 1.5, search_duration: 2.0, reranking_duration: 1.0, embedding_provider: "openai" })
5153
end
5254

5355
context "when then are more results than the configured max_results" do

0 commit comments

Comments
 (0)