diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingOptions.java index 439114dc7e..f1baed2010 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingOptions.java @@ -42,6 +42,11 @@ public class OpenAiEmbeddingOptions extends AbstractOpenAiOptions implements Emb */ private @Nullable String user; + /** + * The format to return the embeddings in. Can be either float or base64. + */ + private EmbeddingCreateParams.@Nullable EncodingFormat encodingFormat; + /* * The number of dimensions the resulting output embeddings should have. Only * supported in `text-embedding-3` and later models. @@ -60,6 +65,14 @@ public void setUser(@Nullable String user) { this.user = user; } + public EmbeddingCreateParams.@Nullable EncodingFormat getEncodingFormat() { + return this.encodingFormat; + } + + public void setEncodingFormat(EmbeddingCreateParams.@Nullable EncodingFormat encodingFormat) { + this.encodingFormat = encodingFormat; + } + @Override public @Nullable Integer getDimensions() { return this.dimensions; @@ -72,7 +85,8 @@ public void setDimensions(@Nullable Integer dimensions) { @Override public String toString() { return "OpenAiEmbeddingOptions{" + "user='" + this.user + '\'' + ", model='" + this.getModel() + '\'' - + ", deploymentName='" + this.getDeploymentName() + '\'' + ", dimensions=" + this.dimensions + '}'; + + ", deploymentName='" + this.getDeploymentName() + '\'' + ", encodingFormat='" + this.encodingFormat + + '\'' + ", dimensions=" + this.dimensions + '}'; } public EmbeddingCreateParams toOpenAiCreateParams(List instructions) { @@ -94,6 +108,9 @@ else if (this.getModel() != null) { if (this.getUser() != null) { builder.user(this.getUser()); } + if (this.getEncodingFormat() != null) { + builder.encodingFormat(this.getEncodingFormat()); + } if (this.getDimensions() != null) { builder.dimensions(this.getDimensions()); } @@ -121,6 +138,7 @@ public Builder from(OpenAiEmbeddingOptions fromOptions) { this.options.setCustomHeaders(fromOptions.getCustomHeaders()); // Child class fields this.options.setUser(fromOptions.getUser()); + this.options.setEncodingFormat(fromOptions.getEncodingFormat()); this.options.setDimensions(fromOptions.getDimensions()); return this; } @@ -164,6 +182,9 @@ public Builder merge(@Nullable EmbeddingOptions from) { if (castFrom.getUser() != null) { this.options.setUser(castFrom.getUser()); } + if (castFrom.getEncodingFormat() != null) { + this.options.setEncodingFormat(castFrom.getEncodingFormat()); + } if (castFrom.getDimensions() != null) { this.options.setDimensions(castFrom.getDimensions()); } @@ -176,6 +197,9 @@ public Builder from(EmbeddingCreateParams openAiCreateParams) { if (openAiCreateParams.user().isPresent()) { this.options.setUser(openAiCreateParams.user().get()); } + if (openAiCreateParams.encodingFormat().isPresent()) { + this.options.setEncodingFormat(openAiCreateParams.encodingFormat().get()); + } if (openAiCreateParams.dimensions().isPresent()) { this.options.setDimensions(Math.toIntExact(openAiCreateParams.dimensions().get())); } @@ -187,6 +211,11 @@ public Builder user(String user) { return this; } + public Builder encodingFormat(EmbeddingCreateParams.EncodingFormat encodingFormat) { + this.options.setEncodingFormat(encodingFormat); + return this; + } + public Builder deploymentName(String deploymentName) { this.options.setDeploymentName(deploymentName); return this; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiEmbeddingOptionsTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiEmbeddingOptionsTests.java new file mode 100644 index 0000000000..2d832c05d1 --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiEmbeddingOptionsTests.java @@ -0,0 +1,64 @@ +/* + * Copyright 2023-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai; + +import java.util.List; + +import com.openai.models.embeddings.EmbeddingCreateParams; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +class OpenAiEmbeddingOptionsTests { + + @Test + void defaultEncodingFormatIsNull() { + OpenAiEmbeddingOptions options = OpenAiEmbeddingOptions.builder().model("test-model").build(); + + EmbeddingCreateParams createParams = options.toOpenAiCreateParams(List.of("test input")); + + assertThat(options.getEncodingFormat()).isNull(); + assertThat(createParams.encodingFormat()).contains(EmbeddingCreateParams.EncodingFormat.BASE64); + } + + @Test + void encodingFormatCanBeConfigured() { + OpenAiEmbeddingOptions options = OpenAiEmbeddingOptions.builder() + .model("test-model") + .encodingFormat(EmbeddingCreateParams.EncodingFormat.FLOAT) + .build(); + + EmbeddingCreateParams createParams = options.toOpenAiCreateParams(List.of("test input")); + + assertThat(createParams.encodingFormat()).contains(EmbeddingCreateParams.EncodingFormat.FLOAT); + } + + @Test + void encodingFormatIsCopiedAndMerged() { + OpenAiEmbeddingOptions source = OpenAiEmbeddingOptions.builder() + .model("test-model") + .encodingFormat(EmbeddingCreateParams.EncodingFormat.FLOAT) + .build(); + + OpenAiEmbeddingOptions copied = OpenAiEmbeddingOptions.builder().from(source).build(); + OpenAiEmbeddingOptions merged = OpenAiEmbeddingOptions.builder().model("other-model").merge(source).build(); + + assertThat(copied.getEncodingFormat()).isEqualTo(EmbeddingCreateParams.EncodingFormat.FLOAT); + assertThat(merged.getEncodingFormat()).isEqualTo(EmbeddingCreateParams.EncodingFormat.FLOAT); + } + +}