|
1 | | -use base64::engine::{general_purpose, Engine}; |
2 | | -use derive_builder::Builder; |
3 | | -use serde::{Deserialize, Serialize}; |
4 | | - |
5 | | -use crate::error::OpenAIError; |
6 | | - |
7 | | -#[derive(Debug, Serialize, Clone, PartialEq, Deserialize)] |
8 | | -#[serde(untagged)] |
9 | | -pub enum EmbeddingInput { |
10 | | - String(String), |
11 | | - StringArray(Vec<String>), |
12 | | - // Minimum value is 0, maximum value is 100257 (inclusive). |
13 | | - IntegerArray(Vec<u32>), |
14 | | - ArrayOfIntegerArray(Vec<Vec<u32>>), |
15 | | -} |
16 | | - |
17 | | -#[derive(Debug, Serialize, Default, Clone, PartialEq, Deserialize)] |
18 | | -#[serde(rename_all = "lowercase")] |
19 | | -pub enum EncodingFormat { |
20 | | - #[default] |
21 | | - Float, |
22 | | - Base64, |
23 | | -} |
24 | | - |
25 | | -#[derive(Debug, Serialize, Default, Clone, Builder, PartialEq, Deserialize)] |
26 | | -#[builder(name = "CreateEmbeddingRequestArgs")] |
27 | | -#[builder(pattern = "mutable")] |
28 | | -#[builder(setter(into, strip_option), default)] |
29 | | -#[builder(derive(Debug))] |
30 | | -#[builder(build_fn(error = "OpenAIError"))] |
31 | | -pub struct CreateEmbeddingRequest { |
32 | | - /// ID of the model to use. You can use the |
33 | | - /// [List models](https://platform.openai.com/docs/api-reference/models/list) |
34 | | - /// API to see all of your available models, or see our |
35 | | - /// [Model overview](https://platform.openai.com/docs/models/overview) |
36 | | - /// for descriptions of them. |
37 | | - pub model: String, |
38 | | - |
39 | | - /// Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. The input must not exceed the max input tokens for the model (8192 tokens for `text-embedding-ada-002`), cannot be an empty string, and any array must be 2048 dimensions or less. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens. |
40 | | - pub input: EmbeddingInput, |
41 | | - |
42 | | - /// The format to return the embeddings in. Can be either `float` or [`base64`](https://pypi.org/project/pybase64/). Defaults to float |
43 | | - #[serde(skip_serializing_if = "Option::is_none")] |
44 | | - pub encoding_format: Option<EncodingFormat>, |
45 | | - |
46 | | - /// A unique identifier representing your end-user, which will help OpenAI |
47 | | - /// to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/usage-policies/end-user-ids). |
48 | | - #[serde(skip_serializing_if = "Option::is_none")] |
49 | | - pub user: Option<String>, |
50 | | - |
51 | | - /// The number of dimensions the resulting output embeddings should have. Only supported in `text-embedding-3` and later models. |
52 | | - #[serde(skip_serializing_if = "Option::is_none")] |
53 | | - pub dimensions: Option<u32>, |
54 | | -} |
55 | | - |
56 | | -/// Represents an embedding vector returned by embedding endpoint. |
57 | | -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] |
58 | | -pub struct Embedding { |
59 | | - /// The index of the embedding in the list of embeddings. |
60 | | - pub index: u32, |
61 | | - /// The object type, which is always "embedding". |
62 | | - pub object: String, |
63 | | - /// The embedding vector, which is a list of floats. The length of vector |
64 | | - /// depends on the model as listed in the [embedding guide](https://platform.openai.com/docs/guides/embeddings). |
65 | | - pub embedding: Vec<f32>, |
66 | | -} |
67 | | - |
68 | | -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] |
69 | | -pub struct Base64EmbeddingVector(pub String); |
70 | | - |
71 | | -impl From<Base64EmbeddingVector> for Vec<f32> { |
72 | | - fn from(value: Base64EmbeddingVector) -> Self { |
73 | | - let bytes = general_purpose::STANDARD |
74 | | - .decode(value.0) |
75 | | - .expect("openai base64 encoding to be valid"); |
76 | | - let chunks = bytes.chunks_exact(4); |
77 | | - chunks |
78 | | - .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])) |
79 | | - .collect() |
80 | | - } |
81 | | -} |
82 | | - |
83 | | -/// Represents an base64-encoded embedding vector returned by embedding endpoint. |
84 | | -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] |
85 | | -pub struct Base64Embedding { |
86 | | - /// The index of the embedding in the list of embeddings. |
87 | | - pub index: u32, |
88 | | - /// The object type, which is always "embedding". |
89 | | - pub object: String, |
90 | | - /// The embedding vector, encoded in base64. |
91 | | - pub embedding: Base64EmbeddingVector, |
92 | | -} |
93 | | - |
94 | | -#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] |
95 | | -pub struct EmbeddingUsage { |
96 | | - /// The number of tokens used by the prompt. |
97 | | - pub prompt_tokens: u32, |
98 | | - /// The total number of tokens used by the request. |
99 | | - pub total_tokens: u32, |
100 | | -} |
101 | | - |
102 | | -#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] |
103 | | -pub struct CreateEmbeddingResponse { |
104 | | - pub object: String, |
105 | | - /// The name of the model used to generate the embedding. |
106 | | - pub model: String, |
107 | | - /// The list of embeddings generated by the model. |
108 | | - pub data: Vec<Embedding>, |
109 | | - /// The usage information for the request. |
110 | | - pub usage: EmbeddingUsage, |
111 | | -} |
112 | | - |
113 | | -#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] |
114 | | -pub struct CreateBase64EmbeddingResponse { |
115 | | - pub object: String, |
116 | | - /// The name of the model used to generate the embedding. |
117 | | - pub model: String, |
118 | | - /// The list of embeddings generated by the model. |
119 | | - pub data: Vec<Base64Embedding>, |
120 | | - /// The usage information for the request. |
121 | | - pub usage: EmbeddingUsage, |
122 | | -} |
| 1 | +use base64::engine::{general_purpose, Engine}; |
| 2 | +use derive_builder::Builder; |
| 3 | +use serde::{Deserialize, Serialize}; |
| 4 | + |
| 5 | +use crate::error::OpenAIError; |
| 6 | + |
| 7 | +#[derive(Debug, Serialize, Clone, PartialEq, Deserialize)] |
| 8 | +#[serde(untagged)] |
| 9 | +pub enum EmbeddingInput { |
| 10 | + String(String), |
| 11 | + StringArray(Vec<String>), |
| 12 | + // Minimum value is 0, maximum value is 100257 (inclusive). |
| 13 | + IntegerArray(Vec<u32>), |
| 14 | + ArrayOfIntegerArray(Vec<Vec<u32>>), |
| 15 | +} |
| 16 | + |
| 17 | +#[derive(Debug, Serialize, Default, Clone, PartialEq, Deserialize)] |
| 18 | +#[serde(rename_all = "lowercase")] |
| 19 | +pub enum EncodingFormat { |
| 20 | + #[default] |
| 21 | + Float, |
| 22 | + Base64, |
| 23 | +} |
| 24 | + |
| 25 | +#[derive(Debug, Serialize, Default, Clone, Builder, PartialEq, Deserialize)] |
| 26 | +#[builder(name = "CreateEmbeddingRequestArgs")] |
| 27 | +#[builder(pattern = "mutable")] |
| 28 | +#[builder(setter(into, strip_option), default)] |
| 29 | +#[builder(derive(Debug))] |
| 30 | +#[builder(build_fn(error = "OpenAIError"))] |
| 31 | +pub struct CreateEmbeddingRequest { |
| 32 | + /// ID of the model to use. You can use the |
| 33 | + /// [List models](https://platform.openai.com/docs/api-reference/models/list) |
| 34 | + /// API to see all of your available models, or see our |
| 35 | + /// [Model overview](https://platform.openai.com/docs/models/overview) |
| 36 | + /// for descriptions of them. |
| 37 | + pub model: String, |
| 38 | + |
| 39 | + /// Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. The input must not exceed the max input tokens for the model (8192 tokens for `text-embedding-ada-002`), cannot be an empty string, and any array must be 2048 dimensions or less. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens. |
| 40 | + pub input: EmbeddingInput, |
| 41 | + |
| 42 | + /// The format to return the embeddings in. Can be either `float` or [`base64`](https://pypi.org/project/pybase64/). Defaults to float |
| 43 | + #[serde(skip_serializing_if = "Option::is_none")] |
| 44 | + pub encoding_format: Option<EncodingFormat>, |
| 45 | + |
| 46 | + /// A unique identifier representing your end-user, which will help OpenAI |
| 47 | + /// to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/usage-policies/end-user-ids). |
| 48 | + #[serde(skip_serializing_if = "Option::is_none")] |
| 49 | + pub user: Option<String>, |
| 50 | + |
| 51 | + /// The number of dimensions the resulting output embeddings should have. Only supported in `text-embedding-3` and later models. |
| 52 | + #[serde(skip_serializing_if = "Option::is_none")] |
| 53 | + pub dimensions: Option<u32>, |
| 54 | +} |
| 55 | + |
| 56 | +/// Represents an embedding vector returned by embedding endpoint. |
| 57 | +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] |
| 58 | +pub struct Embedding { |
| 59 | + /// The index of the embedding in the list of embeddings. |
| 60 | + pub index: u32, |
| 61 | + /// The object type, which is always "embedding". |
| 62 | + pub object: String, |
| 63 | + /// The embedding vector, which is a list of floats. The length of vector |
| 64 | + /// depends on the model as listed in the [embedding guide](https://platform.openai.com/docs/guides/embeddings). |
| 65 | + pub embedding: Vec<f32>, |
| 66 | +} |
| 67 | + |
| 68 | +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] |
| 69 | +pub struct Base64EmbeddingVector(pub String); |
| 70 | + |
| 71 | +impl From<Base64EmbeddingVector> for Vec<f32> { |
| 72 | + fn from(value: Base64EmbeddingVector) -> Self { |
| 73 | + let bytes = general_purpose::STANDARD |
| 74 | + .decode(value.0) |
| 75 | + .expect("openai base64 encoding to be valid"); |
| 76 | + let chunks = bytes.chunks_exact(4); |
| 77 | + chunks |
| 78 | + .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])) |
| 79 | + .collect() |
| 80 | + } |
| 81 | +} |
| 82 | + |
| 83 | +/// Represents an base64-encoded embedding vector returned by embedding endpoint. |
| 84 | +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] |
| 85 | +pub struct Base64Embedding { |
| 86 | + /// The index of the embedding in the list of embeddings. |
| 87 | + pub index: u32, |
| 88 | + /// The object type, which is always "embedding". |
| 89 | + pub object: String, |
| 90 | + /// The embedding vector, encoded in base64. |
| 91 | + pub embedding: Base64EmbeddingVector, |
| 92 | +} |
| 93 | + |
| 94 | +#[derive(Debug, Deserialize, Serialize, Clone, PartialEq)] |
| 95 | +pub struct EmbeddingUsage { |
| 96 | + /// The number of tokens used by the prompt. |
| 97 | + pub prompt_tokens: u32, |
| 98 | + /// The total number of tokens used by the request. |
| 99 | + pub total_tokens: u32, |
| 100 | +} |
| 101 | + |
| 102 | +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] |
| 103 | +pub struct CreateEmbeddingResponse { |
| 104 | + pub object: String, |
| 105 | + /// The name of the model used to generate the embedding. |
| 106 | + pub model: String, |
| 107 | + /// The list of embeddings generated by the model. |
| 108 | + pub data: Vec<Embedding>, |
| 109 | + /// The usage information for the request. |
| 110 | + pub usage: EmbeddingUsage, |
| 111 | +} |
| 112 | + |
| 113 | +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] |
| 114 | +pub struct CreateBase64EmbeddingResponse { |
| 115 | + pub object: String, |
| 116 | + /// The name of the model used to generate the embedding. |
| 117 | + pub model: String, |
| 118 | + /// The list of embeddings generated by the model. |
| 119 | + pub data: Vec<Base64Embedding>, |
| 120 | + /// The usage information for the request. |
| 121 | + pub usage: EmbeddingUsage, |
| 122 | +} |
0 commit comments