Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions config/application.rb
Original file line number Diff line number Diff line change
Expand Up @@ -106,5 +106,6 @@ class Application < Rails::Application
config.bigquery_dataset_id = ENV["BIGQUERY_DATASET"]

config.answer_strategy = ENV.fetch("ANSWER_STRATEGY", "openai_structured_answer")
config.embedding_provider = ENV.fetch("EMBEDDING_PROVIDER", "openai")
end
end
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def id_digests
end

def index_chunks(indexable_chunks)
embeddings = Search::TextToEmbedding.call(indexable_chunks.map(&:plain_content))
embeddings = Search::TextToEmbedding.call(indexable_chunks.map(&:plain_content), llm_provider: :openai)

created = 0
updated = 0
Expand Down
11 changes: 9 additions & 2 deletions lib/search/chunked_content_repository.rb
Original file line number Diff line number Diff line change
Expand Up @@ -141,14 +141,21 @@ def id_digest_hash(base_path, batch_size: 100)
items
end

def search_by_embedding(embedding, max_chunks:)
def search_by_embedding(embedding, max_chunks:, llm_provider:)
field_name = case llm_provider.to_sym
when :openai
:openai_embedding
else
raise "Unknown provider: #{llm_provider}"
end

response = client.search(
index:,
body: {
size: max_chunks,
query: {
knn: {
openai_embedding: {
"#{field_name}": {
vector: embedding,
k: max_chunks,
},
Expand Down
11 changes: 9 additions & 2 deletions lib/search/results_for_question.rb
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,21 @@ def self.call(question_message)
max_results = Rails.configuration.search.thresholds.max_results
max_chunks = Rails.configuration.search.thresholds.retrieved_from_index

provider = Rails.configuration.embedding_provider

metrics = {}
embedding_start_time = Clock.monotonic_time
embedding = Search::TextToEmbedding.call(question_message)
embedding = Search::TextToEmbedding.call(question_message, llm_provider: provider)
metrics[:embedding_duration] = Clock.monotonic_time - embedding_start_time

search_start_time = Clock.monotonic_time
results = ChunkedContentRepository.new.search_by_embedding(embedding, max_chunks:)
results = ChunkedContentRepository.new.search_by_embedding(
embedding,
max_chunks:,
llm_provider: provider,
)
metrics[:search_duration] = Clock.monotonic_time - search_start_time
metrics[:embedding_provider] = provider

reranking_start_time = Clock.monotonic_time
weighted_results = Search::ResultsForQuestion::Reranker.call(results)
Expand Down
56 changes: 6 additions & 50 deletions lib/search/text_to_embedding.rb
Original file line number Diff line number Diff line change
@@ -1,56 +1,12 @@
module Search
class TextToEmbedding
EMBEDDING_MODEL = "text-embedding-3-large".freeze
INPUT_TOKEN_LIMIT = 8191
BATCH_SIZE = 50

def self.call(...) = new(...).call

def initialize(single_or_collection_of_text)
@string_input = single_or_collection_of_text.is_a?(String)
@text_collection = Array(single_or_collection_of_text)
end

def call
to_embed = text_collection.map(&method(:keep_input_within_token_limit))

embeddings = convert_text_to_embeddings(to_embed)

# return just first embedding rather than an array of embeddings if we
# weren't given an array input
string_input ? embeddings.first : embeddings
end

private

attr_reader :string_input, :text_collection

def openai_client
@openai_client ||= OpenAIClient.build
end

def keep_input_within_token_limit(text)
as_tokens = token_encoder.encode(text)

return text if as_tokens.length <= INPUT_TOKEN_LIMIT

token_encoder.decode(as_tokens[...INPUT_TOKEN_LIMIT])
end

def convert_text_to_embeddings(to_embed_collection)
batches = to_embed_collection.each_slice(BATCH_SIZE).to_a

batches.flat_map do |batch|
response = openai_client.embeddings(
parameters: { model: EMBEDDING_MODEL, input: batch },
)

response["data"].map { |data| data["embedding"] }
def self.call(single_or_collection_of_text, llm_provider:)
case llm_provider.to_sym
when :openai
Search::TextToEmbedding::OpenAI.call(single_or_collection_of_text)
else
raise "Unknown provider: #{llm_provider}"
end
end

def token_encoder
@token_encoder ||= Tiktoken.encoding_for_model(EMBEDDING_MODEL)
end
end
end
56 changes: 56 additions & 0 deletions lib/search/text_to_embedding/openai.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
class Search::TextToEmbedding
class OpenAI
EMBEDDING_MODEL = "text-embedding-3-large".freeze
INPUT_TOKEN_LIMIT = 8191
BATCH_SIZE = 50

def self.call(...) = new(...).call

def initialize(single_or_collection_of_text)
@string_input = single_or_collection_of_text.is_a?(String)
@text_collection = Array(single_or_collection_of_text)
end

def call
to_embed = text_collection.map(&method(:keep_input_within_token_limit))

embeddings = convert_text_to_embeddings(to_embed)

# return just first embedding rather than an array of embeddings if we
# weren't given an array input
string_input ? embeddings.first : embeddings
end

private

attr_reader :string_input, :text_collection

def openai_client
@openai_client ||= OpenAIClient.build
end

def keep_input_within_token_limit(text)
as_tokens = token_encoder.encode(text)

return text if as_tokens.length <= INPUT_TOKEN_LIMIT

token_encoder.decode(as_tokens[...INPUT_TOKEN_LIMIT])
end

def convert_text_to_embeddings(to_embed_collection)
batches = to_embed_collection.each_slice(BATCH_SIZE).to_a

batches.flat_map do |batch|
response = openai_client.embeddings(
parameters: { model: EMBEDDING_MODEL, input: batch },
)

response["data"].map { |data| data["embedding"] }
end
end

def token_encoder
@token_encoder ||= Tiktoken.encoding_for_model(EMBEDDING_MODEL)
end
end
end
2 changes: 1 addition & 1 deletion lib/tasks/search.rake
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ namespace :search do
end
end

embeddings = Search::TextToEmbedding.call(chunks.map(&:plain_content))
embeddings = Search::TextToEmbedding.call(chunks.map(&:plain_content), llm_provider: :openai)
repository = Search::ChunkedContentRepository.new
indexed = 0

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@
end

it "applies OpenAI embedding to the data going into the search index" do
allow(Search::TextToEmbedding).to receive(:call).and_call_original
allow(Search::TextToEmbedding::OpenAI).to receive(:call).and_call_original

expect { described_class.call(content_item, repository) }
.to change { repository.count(exists: { field: :openai_embedding }) }
.by(chunks.length)

expect(Search::TextToEmbedding).to have_received(:call)
expect(Search::TextToEmbedding::OpenAI).to have_received(:call)
end

it "returns a Result object" do
Expand Down
24 changes: 21 additions & 3 deletions spec/lib/search/chunked_content_repository_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,11 @@
end

it "returns an array of Result objects" do
result = repository.search_by_embedding(openai_embedding, max_chunks: 10)
result = repository.search_by_embedding(
openai_embedding,
max_chunks: 10,
llm_provider: :openai,
)
expected_attributes = chunked_content_records.first
.except(:openai_embedding)
.merge(score: a_value_between(0.9, 1))
Expand All @@ -229,12 +233,26 @@
expect(result.first).to have_attributes(**expected_attributes)
end

context "when there are more then the maxiumum chunks" do
it "raises an error if the llm provider is not recognised" do
expect {
repository.search_by_embedding(
openai_embedding,
max_chunks: 10,
llm_provider: :unknown,
)
}.to raise_error("Unknown provider: unknown")
end

context "when there are more than the maxiumum chunks" do
let(:max_chunks) { 10 }
let(:chunked_content_records) { build_list(:chunked_content_record, 11, openai_embedding:) }

it "only returns the first max_chunks" do
result = repository.search_by_embedding(openai_embedding, max_chunks:)
result = repository.search_by_embedding(
openai_embedding,
max_chunks:,
llm_provider: :openai,
)
expect(result.count).to eq max_chunks
end
end
Expand Down
10 changes: 6 additions & 4 deletions spec/lib/search/results_for_question_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
before do
allow(Search::TextToEmbedding)
.to receive(:call)
.with(question_message)
.with(question_message, llm_provider: "openai")
.and_return(openai_embedding)

allow(Rails.configuration.search.thresholds).to receive_messages(minimum_score: min_score, max_results:)
Expand All @@ -25,9 +25,11 @@
end

it "retrieves an embedding for the question_message and searches the chunked content repository" do
allow(Rails.configuration).to receive(:embedding_provider).and_return("openai")

result = described_class.call(question_message)
expect(result).to be_a(Search::ResultsForQuestion::ResultSet)
expect(Search::TextToEmbedding).to have_received(:call).with(question_message)
expect(Search::TextToEmbedding).to have_received(:call).with(question_message, llm_provider: "openai")
end

it "has the results over the configured threshold after reranking" do
Expand All @@ -45,9 +47,9 @@
["not found 2", a_value_between(0.2, 0.3)])
end

it "populates the metrics attribute with the durations of the embedding, search, and reranking steps" do
it "populates the metrics attribute" do
result = described_class.call(question_message)
expect(result.metrics).to eq({ embedding_duration: 1.5, search_duration: 2.0, reranking_duration: 1.0 })
expect(result.metrics).to eq({ embedding_duration: 1.5, search_duration: 2.0, reranking_duration: 1.0, embedding_provider: "openai" })
end

context "when then are more results than the configured max_results" do
Expand Down
55 changes: 55 additions & 0 deletions spec/lib/search/text_to_embedding/open_ai_spec.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
RSpec.describe Search::TextToEmbedding::OpenAI do
describe ".call" do
it "returns a single embedding array for a string input" do
to_embed = "text"
stub_openai_embedding(to_embed)

embedding = described_class.call(to_embed)

expect(embedding)
.to be_an_instance_of(Array)
.and have_attributes(length: Search::ChunkedContentRepository::OPENAI_EMBEDDING_DIMENSIONS)
end

it "returns an array of embedding arrays for an array input" do
to_embed = ["Embed this", "Embed that"]
stub_openai_embedding(to_embed)

embedding_collection = described_class.call(to_embed)

expect(embedding_collection)
.to be_an_instance_of(Array)
.and have_attributes(length: 2)
.and all(have_attributes(length: Search::ChunkedContentRepository::OPENAI_EMBEDDING_DIMENSIONS))
end

it "does multiple requests to OpenAI when the number of strings is greater than the batch size" do
input_1 = Array.new(described_class::BATCH_SIZE, "to embed")
input_2 = Array.new(5, "to embed in a second request")

request_1 = stub_openai_embedding(input_1)
request_2 = stub_openai_embedding(input_2)

described_class.call(input_1 + input_2)

expect(request_1).to have_been_made
expect(request_2).to have_been_made
end

it "truncates input that exceeds the token limit to avoid a context length exceeded error" do
very_long_input = "test " * 10_000

encoder = Tiktoken.encoding_for_model(described_class::EMBEDDING_MODEL)

request = stub_any_openai_embedding(embeddings_per_request: 1).with do |req|
input = JSON.parse(req.body).dig("input", 0)
input_tokens = encoder.encode(input)
input.match(/test\s/) && input_tokens.length == described_class::INPUT_TOKEN_LIMIT
end

described_class.call(very_long_input)

expect(request).to have_been_made
end
end
end
Loading