|
1 | 1 | module Search |
2 | 2 | class TextToEmbedding |
3 | | - EMBEDDING_MODEL = "text-embedding-3-large".freeze |
4 | | - INPUT_TOKEN_LIMIT = 8191 |
5 | | - BATCH_SIZE = 50 |
6 | | - |
7 | | - def self.call(...) = new(...).call |
8 | | - |
9 | | - def initialize(single_or_collection_of_text) |
10 | | - @string_input = single_or_collection_of_text.is_a?(String) |
11 | | - @text_collection = Array(single_or_collection_of_text) |
12 | | - end |
13 | | - |
14 | | - def call |
15 | | - to_embed = text_collection.map(&method(:keep_input_within_token_limit)) |
16 | | - |
17 | | - embeddings = convert_text_to_embeddings(to_embed) |
18 | | - |
19 | | - # return just first embedding rather than an array of embeddings if we |
20 | | - # weren't given an array input |
21 | | - string_input ? embeddings.first : embeddings |
22 | | - end |
23 | | - |
24 | | - private |
25 | | - |
26 | | - attr_reader :string_input, :text_collection |
27 | | - |
28 | | - def openai_client |
29 | | - @openai_client ||= OpenAIClient.build |
30 | | - end |
31 | | - |
32 | | - def keep_input_within_token_limit(text) |
33 | | - as_tokens = token_encoder.encode(text) |
34 | | - |
35 | | - return text if as_tokens.length <= INPUT_TOKEN_LIMIT |
36 | | - |
37 | | - token_encoder.decode(as_tokens[...INPUT_TOKEN_LIMIT]) |
38 | | - end |
39 | | - |
40 | | - def convert_text_to_embeddings(to_embed_collection) |
41 | | - batches = to_embed_collection.each_slice(BATCH_SIZE).to_a |
42 | | - |
43 | | - batches.flat_map do |batch| |
44 | | - response = openai_client.embeddings( |
45 | | - parameters: { model: EMBEDDING_MODEL, input: batch }, |
46 | | - ) |
47 | | - |
48 | | - response["data"].map { |data| data["embedding"] } |
| 3 | + def self.call(single_or_collection_of_text, llm_provider:) |
| 4 | + case llm_provider.to_sym |
| 5 | + when :openai |
| 6 | + Search::TextToEmbedding::OpenAI.call(single_or_collection_of_text) |
| 7 | + else |
| 8 | + raise "Unknown provider: #{llm_provider}" |
49 | 9 | end |
50 | 10 | end |
51 | | - |
52 | | - def token_encoder |
53 | | - @token_encoder ||= Tiktoken.encoding_for_model(EMBEDDING_MODEL) |
54 | | - end |
55 | 11 | end |
56 | 12 | end |
0 commit comments