Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
java_build:
strategy:
matrix:
java_version: [ 17, 21, 22, 23 ]
java_version: [ 17, 21, 25 ]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,5 @@ build/

### .env files contain local environment variables ###
.env

*.log
Binary file modified .mvn/wrapper/maven-wrapper.jar
Binary file not shown.
4 changes: 2 additions & 2 deletions .mvn/wrapper/maven-wrapper.properties
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@
# under the License.
wrapperVersion=3.3.2
distributionType=only-script
distributionUrl=https://repo.maven.apache.org/maven2/org/apache/maven/apache-maven/3.9.11/apache-maven-3.9.11-bin.zip
distributionSha256Sum=0d7125e8c91097b36edb990ea5934e6c68b4440eef4ea96510a0f6815e7eeadb
distributionUrl=https://repo.maven.apache.org/maven2/org/apache/maven/apache-maven/3.9.12/apache-maven-3.9.12-bin.zip
distributionSha256Sum=305773a68d6ddfd413df58c82b3f8050e89778e777f3a745c8e5b8cbea4018ef
2 changes: 1 addition & 1 deletion .sdkmanrc
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
# See https://sdkman.io/usage#config
# A summary is to add the following to ~/.sdkman/etc/config
# sdkman_auto_env=true
java=17.0.14-tem
java=17.0.17-tem
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import java.util.List;
import java.util.Map;

class ChatModelProperties {
public class ChatModelProperties {

String baseUrl;
String apiKey;
Expand Down
2 changes: 1 addition & 1 deletion langchain4j-elasticsearch-spring-boot-starter/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
<version>2.18.0</version>
<version>2.21.0</version>
<scope>test</scope>
</dependency>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import java.time.Duration;
import java.util.List;

class ChatModelProperties {
public class ChatModelProperties {

private String endpoint;
private String gitHubToken;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import java.time.Duration;

class EmbeddingModelProperties {
public class EmbeddingModelProperties {

private String endpoint;
private String gitHubToken;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import java.util.List;
import java.util.Map;

record ChatModelProperties(
public record ChatModelProperties(
String apiKey,
String baseUrl,
String modelName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import java.time.Duration;

record EmbeddingModelProperties(
public record EmbeddingModelProperties(
String apiKey,
String modelName,
String titleMetadataKey,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import java.util.List;

record GeminiFunctionCallingConfig(
public record GeminiFunctionCallingConfig(
GeminiMode geminiMode,
List<String> allowedFunctionNames
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import dev.langchain4j.http.client.SuccessfulHttpResponse;
import dev.langchain4j.http.client.sse.ServerSentEventListener;
import dev.langchain4j.http.client.sse.ServerSentEventParser;
import org.springframework.boot.web.client.ClientHttpRequestFactories;
import org.springframework.boot.web.client.ClientHttpRequestFactorySettings;
import org.springframework.boot.http.client.ClientHttpRequestFactoryBuilder;
import org.springframework.boot.http.client.ClientHttpRequestFactorySettings;
import org.springframework.core.io.ByteArrayResource;
import org.springframework.core.task.AsyncTaskExecutor;
import org.springframework.http.ResponseEntity;
Expand All @@ -36,14 +36,14 @@ public SpringRestClient(SpringRestClientBuilder builder) {

RestClient.Builder restClientBuilder = getOrDefault(builder.restClientBuilder(), RestClient::builder);

ClientHttpRequestFactorySettings settings = ClientHttpRequestFactorySettings.DEFAULTS;
ClientHttpRequestFactorySettings settings = ClientHttpRequestFactorySettings.defaults();
if (builder.connectTimeout() != null) {
settings = settings.withConnectTimeout(builder.connectTimeout());
}
if (builder.readTimeout() != null) {
settings = settings.withReadTimeout(builder.readTimeout());
}
ClientHttpRequestFactory clientHttpRequestFactory = ClientHttpRequestFactories.get(settings);
ClientHttpRequestFactory clientHttpRequestFactory = ClientHttpRequestFactoryBuilder.detect().build(settings);

this.delegate = restClientBuilder
.requestFactory(clientHttpRequestFactory)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.http.client.JdkClientHttpRequestFactory;
import org.springframework.http.client.ReactorNettyClientRequestFactory;
import org.springframework.http.client.ReactorClientHttpRequestFactory;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.web.client.RestClient;

Expand All @@ -23,7 +23,7 @@ protected List<HttpClient> clients() {
.restClientBuilder(RestClient.builder().requestFactory(new HttpComponentsClientHttpRequestFactory()))
.build(),
SpringRestClient.builder()
.restClientBuilder(RestClient.builder().requestFactory(new ReactorNettyClientRequestFactory()))
.restClientBuilder(RestClient.builder().requestFactory(new ReactorClientHttpRequestFactory()))
.build(),
SpringRestClient.builder()
.restClientBuilder(RestClient.builder().requestFactory(new SimpleClientHttpRequestFactory()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import dev.langchain4j.http.client.HttpClientTimeoutIT;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.http.client.JdkClientHttpRequestFactory;
import org.springframework.http.client.ReactorNettyClientRequestFactory;
import org.springframework.http.client.ReactorClientHttpRequestFactory;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.web.client.RestClient;

Expand All @@ -26,7 +26,7 @@ protected List<HttpClient> clients(Duration readTimeout) {
.readTimeout(readTimeout)
.build(),
SpringRestClient.builder()
.restClientBuilder(RestClient.builder().requestFactory(new ReactorNettyClientRequestFactory()))
.restClientBuilder(RestClient.builder().requestFactory(new ReactorClientHttpRequestFactory()))
.readTimeout(readTimeout)
.build(),
SpringRestClient.builder()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
package dev.langchain4j.mistralai.spring;

import static dev.langchain4j.mistralai.spring.Properties.PREFIX;

import dev.langchain4j.http.client.HttpClientBuilder;
import dev.langchain4j.http.client.spring.restclient.SpringRestClient;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.mistralai.MistralAiChatModel;
import dev.langchain4j.model.mistralai.MistralAiEmbeddingModel;
import dev.langchain4j.model.mistralai.MistralAiFimModel;
import dev.langchain4j.model.mistralai.MistralAiModerationModel;
import dev.langchain4j.model.mistralai.MistralAiStreamingChatModel;
import dev.langchain4j.model.mistralai.MistralAiStreamingFimModel;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.autoconfigure.AutoConfiguration;
Expand All @@ -16,15 +23,7 @@
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.web.client.RestClient;

import dev.langchain4j.http.client.HttpClientBuilder;
import dev.langchain4j.http.client.spring.restclient.SpringRestClient;
import dev.langchain4j.model.chat.listener.ChatModelListener;
import dev.langchain4j.model.mistralai.MistralAiChatModel;
import dev.langchain4j.model.mistralai.MistralAiEmbeddingModel;
import dev.langchain4j.model.mistralai.MistralAiFimModel;
import dev.langchain4j.model.mistralai.MistralAiModerationModel;
import dev.langchain4j.model.mistralai.MistralAiStreamingChatModel;
import dev.langchain4j.model.mistralai.MistralAiStreamingFimModel;
import static dev.langchain4j.mistralai.spring.Properties.PREFIX;

@AutoConfiguration
@EnableConfigurationProperties(Properties.class)
Expand Down Expand Up @@ -55,29 +54,29 @@ MistralAiChatModel mistralAiChatModel(
ChatModelProperties chatModelProperties = properties.getChatModel();
MistralAiChatModel.MistralAiChatModelBuilder builder = MistralAiChatModel.builder()
.httpClientBuilder(httpClientBuilder)
.baseUrl(chatModelProperties.getBaseUrl())
.apiKey(chatModelProperties.getApiKey())
.modelName(chatModelProperties.getModelName())
.temperature(chatModelProperties.getTemperature())
.topP(chatModelProperties.getTopP())
.maxTokens(chatModelProperties.getMaxTokens())
.safePrompt(chatModelProperties.getSafePrompt())
.randomSeed(chatModelProperties.getRandomSeed())
.responseFormat(chatModelProperties.getResponseFormat())
.stopSequences(chatModelProperties.getStopSequences())
.frequencyPenalty(chatModelProperties.getFrequencyPenalty())
.presencePenalty(chatModelProperties.getPresencePenalty())
.timeout(chatModelProperties.getTimeout())
.logRequests(chatModelProperties.getLogRequests())
.logResponses(chatModelProperties.getLogResponses())
.baseUrl(chatModelProperties.baseUrl())
.apiKey(chatModelProperties.apiKey())
.modelName(chatModelProperties.modelName())
.temperature(chatModelProperties.temperature())
.topP(chatModelProperties.topP())
.maxTokens(chatModelProperties.maxTokens())
.safePrompt(chatModelProperties.safePrompt())
.randomSeed(chatModelProperties.randomSeed())
.responseFormat(chatModelProperties.responseFormat())
.stopSequences(chatModelProperties.stopSequences())
.frequencyPenalty(chatModelProperties.frequencyPenalty())
.presencePenalty(chatModelProperties.presencePenalty())
.timeout(chatModelProperties.timeout())
.logRequests(chatModelProperties.logRequests())
.logResponses(chatModelProperties.logResponses())
.listeners(listeners.orderedStream().toList());

// Conditional parameters to avoid NPE in Mistral AI models
if (chatModelProperties.getMaxRetries() != null) {
builder.maxRetries(chatModelProperties.getMaxRetries());
if (chatModelProperties.maxRetries() != null) {
builder.maxRetries(chatModelProperties.maxRetries());
}
if (chatModelProperties.getSupportedCapabilities() != null) {
builder.supportedCapabilities(chatModelProperties.getSupportedCapabilities());
if (chatModelProperties.supportedCapabilities() != null) {
builder.supportedCapabilities(chatModelProperties.supportedCapabilities());
}

return builder.build();
Expand All @@ -102,26 +101,26 @@ MistralAiStreamingChatModel mistralAiStreamingChatModel(
ChatModelProperties chatModelProperties = properties.getStreamingChatModel();
MistralAiStreamingChatModel.MistralAiStreamingChatModelBuilder builder = MistralAiStreamingChatModel.builder()
.httpClientBuilder(httpClientBuilder)
.baseUrl(chatModelProperties.getBaseUrl())
.apiKey(chatModelProperties.getApiKey())
.modelName(chatModelProperties.getModelName())
.temperature(chatModelProperties.getTemperature())
.topP(chatModelProperties.getTopP())
.maxTokens(chatModelProperties.getMaxTokens())
.safePrompt(chatModelProperties.getSafePrompt())
.randomSeed(chatModelProperties.getRandomSeed())
.responseFormat(chatModelProperties.getResponseFormat())
.stopSequences(chatModelProperties.getStopSequences())
.frequencyPenalty(chatModelProperties.getFrequencyPenalty())
.presencePenalty(chatModelProperties.getPresencePenalty())
.timeout(chatModelProperties.getTimeout())
.logRequests(chatModelProperties.getLogRequests())
.logResponses(chatModelProperties.getLogResponses())
.baseUrl(chatModelProperties.baseUrl())
.apiKey(chatModelProperties.apiKey())
.modelName(chatModelProperties.modelName())
.temperature(chatModelProperties.temperature())
.topP(chatModelProperties.topP())
.maxTokens(chatModelProperties.maxTokens())
.safePrompt(chatModelProperties.safePrompt())
.randomSeed(chatModelProperties.randomSeed())
.responseFormat(chatModelProperties.responseFormat())
.stopSequences(chatModelProperties.stopSequences())
.frequencyPenalty(chatModelProperties.frequencyPenalty())
.presencePenalty(chatModelProperties.presencePenalty())
.timeout(chatModelProperties.timeout())
.logRequests(chatModelProperties.logRequests())
.logResponses(chatModelProperties.logResponses())
.listeners(listeners.orderedStream().toList());

// Conditional parameters to avoid NPE in Mistral AI models
if (chatModelProperties.getSupportedCapabilities() != null) {
builder.supportedCapabilities(chatModelProperties.getSupportedCapabilities());
if (chatModelProperties.supportedCapabilities() != null) {
builder.supportedCapabilities(chatModelProperties.supportedCapabilities());
}

return builder.build();
Expand Down Expand Up @@ -169,13 +168,13 @@ MistralAiEmbeddingModel mistralAiEmbeddingModel(
EmbeddingModelProperties embeddingModelProperties = properties.getEmbeddingModel();
return MistralAiEmbeddingModel.builder()
.httpClientBuilder(httpClientBuilder)
.baseUrl(embeddingModelProperties.getBaseUrl())
.apiKey(embeddingModelProperties.getApiKey())
.modelName(embeddingModelProperties.getModelName())
.timeout(embeddingModelProperties.getTimeout())
.logRequests(embeddingModelProperties.getLogRequests())
.logResponses(embeddingModelProperties.getLogResponses())
.maxRetries(embeddingModelProperties.getMaxRetries())
.baseUrl(embeddingModelProperties.baseUrl())
.apiKey(embeddingModelProperties.apiKey())
.modelName(embeddingModelProperties.modelName())
.timeout(embeddingModelProperties.timeout())
.logRequests(embeddingModelProperties.logRequests())
.logResponses(embeddingModelProperties.logResponses())
.maxRetries(embeddingModelProperties.maxRetries())
.build();
}

Expand All @@ -197,19 +196,19 @@ MistralAiFimModel mistralAiFimModel(
FimModelProperties fimModelProperties = properties.getFimModel();
return MistralAiFimModel.builder()
.httpClientBuilder(httpClientBuilder)
.baseUrl(fimModelProperties.getBaseUrl())
.apiKey(fimModelProperties.getApiKey())
.modelName(fimModelProperties.getModelName())
.temperature(fimModelProperties.getTemperature())
.maxTokens(fimModelProperties.getMaxTokens())
.minTokens(fimModelProperties.getMinTokens())
.topP(fimModelProperties.getTopP())
.randomSeed(fimModelProperties.getRandomSeed())
.stop(fimModelProperties.getStop())
.timeout(fimModelProperties.getTimeout())
.logRequests(fimModelProperties.getLogRequests())
.logResponses(fimModelProperties.getLogResponses())
.maxRetries(fimModelProperties.getMaxRetries())
.baseUrl(fimModelProperties.baseUrl())
.apiKey(fimModelProperties.apiKey())
.modelName(fimModelProperties.modelName())
.temperature(fimModelProperties.temperature())
.maxTokens(fimModelProperties.maxTokens())
.minTokens(fimModelProperties.minTokens())
.topP(fimModelProperties.topP())
.randomSeed(fimModelProperties.randomSeed())
.stop(fimModelProperties.stop())
.timeout(fimModelProperties.timeout())
.logRequests(fimModelProperties.logRequests())
.logResponses(fimModelProperties.logResponses())
.maxRetries(fimModelProperties.maxRetries())
.build();
}

Expand All @@ -231,18 +230,18 @@ MistralAiStreamingFimModel mistralAiStreamingFimModel(
FimModelProperties fimModelProperties = properties.getStreamingFimModel();
return MistralAiStreamingFimModel.builder()
.httpClientBuilder(httpClientBuilder)
.baseUrl(fimModelProperties.getBaseUrl())
.apiKey(fimModelProperties.getApiKey())
.modelName(fimModelProperties.getModelName())
.temperature(fimModelProperties.getTemperature())
.maxTokens(fimModelProperties.getMaxTokens())
.minTokens(fimModelProperties.getMinTokens())
.topP(fimModelProperties.getTopP())
.randomSeed(fimModelProperties.getRandomSeed())
.stop(fimModelProperties.getStop())
.timeout(fimModelProperties.getTimeout())
.logRequests(fimModelProperties.getLogRequests())
.logResponses(fimModelProperties.getLogResponses())
.baseUrl(fimModelProperties.baseUrl())
.apiKey(fimModelProperties.apiKey())
.modelName(fimModelProperties.modelName())
.temperature(fimModelProperties.temperature())
.maxTokens(fimModelProperties.maxTokens())
.minTokens(fimModelProperties.minTokens())
.topP(fimModelProperties.topP())
.randomSeed(fimModelProperties.randomSeed())
.stop(fimModelProperties.stop())
.timeout(fimModelProperties.timeout())
.logRequests(fimModelProperties.logRequests())
.logResponses(fimModelProperties.logResponses())
.build();
}

Expand Down Expand Up @@ -288,16 +287,16 @@ MistralAiModerationModel mistralAiModerationModel(
ModerationModelProperties moderationModelProperties = properties.getModerationModel();
MistralAiModerationModel.Builder builder = new MistralAiModerationModel.Builder()
.httpClientBuilder(httpClientBuilder)
.baseUrl(moderationModelProperties.getBaseUrl())
.apiKey(moderationModelProperties.getApiKey())
.modelName(moderationModelProperties.getModelName())
.timeout(moderationModelProperties.getTimeout())
.logRequests(moderationModelProperties.getLogRequests())
.logResponses(moderationModelProperties.getLogResponses());
.baseUrl(moderationModelProperties.baseUrl())
.apiKey(moderationModelProperties.apiKey())
.modelName(moderationModelProperties.modelName())
.timeout(moderationModelProperties.timeout())
.logRequests(moderationModelProperties.logRequests())
.logResponses(moderationModelProperties.logResponses());

// Conditional parameter to avoid NPE in Mistral AI models
if (moderationModelProperties.getMaxRetries() != null) {
builder.maxRetries(moderationModelProperties.getMaxRetries());
if (moderationModelProperties.maxRetries() != null) {
builder.maxRetries(moderationModelProperties.maxRetries());
}

return builder.build();
Expand Down
Loading