Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,17 @@
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Pattern;
import jakarta.validation.constraints.Size;
import lombok.Builder;

import java.time.Instant;
import java.util.List;
import java.util.Map;
import java.util.UUID;

import static com.comet.opik.utils.ValidationUtils.NULL_OR_NOT_BLANK;

@Builder(toBuilder = true)
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy.class)
Expand All @@ -26,6 +30,9 @@ public record ProviderApiKey(
@JsonView({View.Public.class,
View.Write.class}) @NotBlank @JsonDeserialize(using = ProviderApiKeyDeserializer.class) String apiKey,
@JsonView({View.Public.class, View.Write.class}) @Size(max = 150) String name,
@JsonView({View.Public.class, View.Write.class}) Map<String, String> headers,
@JsonView({View.Public.class,
View.Write.class}) @Pattern(regexp = NULL_OR_NOT_BLANK, message = "must not be blank") String baseUrl,
@JsonView({View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant createdAt,
@JsonView({View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String createdBy,
@JsonView({View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) Instant lastUpdatedAt,
Expand All @@ -35,7 +42,11 @@ public record ProviderApiKey(
public String toString() {
return "ProviderApiKey{" +
"id=" + id +
", provider='" + provider + '\'' +
", provider=" + provider +
", apiKey='*******'" +
", name='" + name + '\'' +
", headers=" + headers +
", baseUrl='" + baseUrl + '\'' +
", createdAt=" + createdAt +
", createdBy='" + createdBy + '\'' +
", lastUpdatedAt=" + lastUpdatedAt +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import com.comet.opik.api.ProviderApiKey;
import com.comet.opik.api.ProviderApiKeyUpdate;
import com.comet.opik.infrastructure.db.MapFlatArgumentFactory;
import com.comet.opik.infrastructure.db.UUIDArgumentFactory;
import org.jdbi.v3.sqlobject.config.RegisterArgumentFactory;
import org.jdbi.v3.sqlobject.config.RegisterColumnMapper;
import org.jdbi.v3.sqlobject.config.RegisterConstructorMapper;
import org.jdbi.v3.sqlobject.customizer.Bind;
import org.jdbi.v3.sqlobject.customizer.BindList;
Expand All @@ -18,11 +20,13 @@

@RegisterConstructorMapper(ProviderApiKey.class)
@RegisterArgumentFactory(UUIDArgumentFactory.class)
@RegisterArgumentFactory(MapFlatArgumentFactory.class)
@RegisterColumnMapper(MapFlatArgumentFactory.class)
public interface LlmProviderApiKeyDAO {

@SqlUpdate("INSERT INTO llm_provider_api_key (id, provider, workspace_id, api_key, name, created_by, last_updated_by) "
@SqlUpdate("INSERT INTO llm_provider_api_key (id, provider, workspace_id, api_key, name, created_by, last_updated_by, headers, base_url) "
+
"VALUES (:bean.id, :bean.provider, :workspaceId, :bean.apiKey, :bean.name, :bean.createdBy, :bean.lastUpdatedBy)")
"VALUES (:bean.id, :bean.provider, :workspaceId, :bean.apiKey, :bean.name, :bean.createdBy, :bean.lastUpdatedBy, :bean.headers, :bean.baseUrl)")
void save(@Bind("workspaceId") String workspaceId,
@BindMethods("bean") ProviderApiKey providerApiKey);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package com.comet.opik.infrastructure.db;

import com.comet.opik.utils.JsonUtils;
import com.fasterxml.jackson.core.type.TypeReference;
import org.jdbi.v3.core.argument.AbstractArgumentFactory;
import org.jdbi.v3.core.argument.Argument;
import org.jdbi.v3.core.config.ConfigRegistry;
import org.jdbi.v3.core.mapper.ColumnMapper;
import org.jdbi.v3.core.statement.StatementContext;

import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Types;
import java.util.Map;
import java.util.Optional;

public class MapFlatArgumentFactory extends AbstractArgumentFactory<Map<String, String>>
implements
ColumnMapper<Map<String, String>> {

public static final TypeReference<Map<String, String>> TYPE_REFERENCE = new TypeReference<>() {
};

public MapFlatArgumentFactory() {
super(Types.VARCHAR);
}

@Override
protected Argument build(Map<String, String> value, ConfigRegistry config) {
return (position, statement, ctx) -> {
if (value == null) {
statement.setNull(position, Types.VARCHAR);
} else {
statement.setObject(position, JsonUtils.readTree(value).toString());
}
};
}

@Override
public Map<String, String> map(ResultSet r, int columnNumber, StatementContext ctx) throws SQLException {
return Optional.ofNullable(r.getString(columnNumber))
.map(value -> JsonUtils.readValue(value, TYPE_REFERENCE))
.orElse(null);
}

@Override
public Map<String, String> map(ResultSet r, String columnLabel, StatementContext ctx) throws SQLException {
return Optional.ofNullable(r.getString(columnLabel))
.map(value -> JsonUtils.readValue(value, TYPE_REFERENCE))
.orElse(null);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package com.comet.opik.infrastructure.llm;

import lombok.ToString;

import java.util.Map;

public record LlmProviderClientApiConfig(@ToString.Exclude String apiKey, Map<String, String> headers, String baseUrl) {

@Override
public String toString() {
return "LlmProviderClientConfig{" +
"apiKey='*********'" +
", headers=" + headers +
", baseUrl='" + baseUrl + '\'' +
'}';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

public interface LlmProviderClientGenerator<T> {

T generate(String apiKey, Object... params);
T generate(LlmProviderClientApiConfig clientConfig, Object... params);

ChatLanguageModel generateChat(String apiKey, LlmAsJudgeModelParameters modelParameters);
ChatLanguageModel generateChat(LlmProviderClientApiConfig clientConfig, LlmAsJudgeModelParameters modelParameters);
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.comet.opik.infrastructure.llm;

import com.comet.opik.api.LlmProvider;
import com.comet.opik.api.ProviderApiKey;
import com.comet.opik.domain.LlmProviderApiKeyService;
import com.comet.opik.domain.llm.LlmProviderFactory;
import com.comet.opik.domain.llm.LlmProviderService;
Expand Down Expand Up @@ -35,21 +36,31 @@ public void register(LlmProvider llmProvider, LlmServiceProvider service) {

public LlmProviderService getService(@NonNull String workspaceId, @NonNull String model) {
var llmProvider = getLlmProvider(model);
var apiKey = EncryptionUtils.decrypt(getEncryptedApiKey(workspaceId, llmProvider));
var providerConfig = getProviderApiKey(workspaceId, llmProvider);

var config = new LlmProviderClientApiConfig(
EncryptionUtils.decrypt(providerConfig.apiKey()),
providerConfig.headers(),
providerConfig.baseUrl());

return Optional.ofNullable(services.get(llmProvider))
.map(provider -> provider.getService(apiKey))
.map(provider -> provider.getService(config))
.orElseThrow(() -> new LlmProviderUnsupportedException(
"LLM provider not supported: %s".formatted(llmProvider)));
}

public ChatLanguageModel getLanguageModel(@NonNull String workspaceId,
@NonNull LlmAsJudgeModelParameters modelParameters) {
var llmProvider = getLlmProvider(modelParameters.name());
var apiKey = EncryptionUtils.decrypt(getEncryptedApiKey(workspaceId, llmProvider));
var providerConfig = getProviderApiKey(workspaceId, llmProvider);

var config = new LlmProviderClientApiConfig(
EncryptionUtils.decrypt(providerConfig.apiKey()),
providerConfig.headers(),
providerConfig.baseUrl());

return Optional.ofNullable(services.get(llmProvider))
.map(provider -> provider.getLanguageModel(apiKey, modelParameters))
.map(provider -> provider.getLanguageModel(config, modelParameters))
.orElseThrow(() -> new BadRequestException(
String.format(ERROR_MODEL_NOT_SUPPORTED, modelParameters.name())));
}
Expand Down Expand Up @@ -78,13 +89,12 @@ private LlmProvider getLlmProvider(String model) {
* Finding API keys isn't paginated at the moment.
* Even in the future, the number of supported LLM providers per workspace is going to be very low.
*/
private String getEncryptedApiKey(String workspaceId, LlmProvider llmProvider) {
private ProviderApiKey getProviderApiKey(String workspaceId, LlmProvider llmProvider) {
return llmProviderApiKeyService.find(workspaceId).content().stream()
.filter(providerApiKey -> llmProvider.equals(providerApiKey.provider()))
.findFirst()
.orElseThrow(() -> new BadRequestException("API key not configured for LLM provider '%s'".formatted(
llmProvider.getValue())))
.apiKey();
llmProvider.getValue())));
}

private static <E extends Enum<E>> boolean isModelBelongToProvider(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

public interface LlmServiceProvider {

LlmProviderService getService(String apiKey);
LlmProviderService getService(LlmProviderClientApiConfig config);

ChatLanguageModel getLanguageModel(String apiKey, LlmAsJudgeModelParameters modelParameters);
ChatLanguageModel getLanguageModel(LlmProviderClientApiConfig config, LlmAsJudgeModelParameters modelParameters);
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.comet.opik.infrastructure.llm.antropic;

import com.comet.opik.infrastructure.LlmProviderClientConfig;
import com.comet.opik.infrastructure.llm.LlmProviderClientApiConfig;
import com.comet.opik.infrastructure.llm.LlmProviderClientGenerator;
import dev.langchain4j.model.anthropic.AnthropicChatModel;
import dev.langchain4j.model.anthropic.internal.client.AnthropicClient;
Expand All @@ -18,12 +19,17 @@ public class AnthropicClientGenerator implements LlmProviderClientGenerator<Anth

private final @NonNull LlmProviderClientConfig llmProviderClientConfig;

private AnthropicClient newAnthropicClient(@NonNull String apiKey) {
private AnthropicClient newAnthropicClient(@NonNull LlmProviderClientApiConfig config) {
var anthropicClientBuilder = AnthropicClient.builder();
Optional.ofNullable(llmProviderClientConfig.getAnthropicClient())
.map(LlmProviderClientConfig.AnthropicClientConfig::url)
.filter(StringUtils::isNotEmpty)
.ifPresent(anthropicClientBuilder::baseUrl);

if (StringUtils.isNotEmpty(config.baseUrl())) {
anthropicClientBuilder.baseUrl(config.baseUrl());
}

Optional.ofNullable(llmProviderClientConfig.getAnthropicClient())
.map(LlmProviderClientConfig.AnthropicClientConfig::version)
.filter(StringUtils::isNotBlank)
Expand All @@ -36,13 +42,14 @@ private AnthropicClient newAnthropicClient(@NonNull String apiKey) {
Optional.ofNullable(llmProviderClientConfig.getCallTimeout())
.ifPresent(callTimeout -> anthropicClientBuilder.timeout(callTimeout.toJavaDuration()));
return anthropicClientBuilder
.apiKey(apiKey)
.apiKey(config.apiKey())
.build();
}

private ChatLanguageModel newChatLanguageModel(String apiKey, LlmAsJudgeModelParameters modelParameters) {
private ChatLanguageModel newChatLanguageModel(LlmProviderClientApiConfig config,
LlmAsJudgeModelParameters modelParameters) {
var builder = AnthropicChatModel.builder()
.apiKey(apiKey)
.apiKey(config.apiKey())
.modelName(modelParameters.name());

Optional.ofNullable(llmProviderClientConfig.getConnectTimeout())
Expand All @@ -53,18 +60,23 @@ private ChatLanguageModel newChatLanguageModel(String apiKey, LlmAsJudgeModelPar
.filter(StringUtils::isNotBlank)
.ifPresent(builder::baseUrl);

if (StringUtils.isNotEmpty(config.baseUrl())) {
builder.baseUrl(config.baseUrl());
}

Optional.ofNullable(modelParameters.temperature()).ifPresent(builder::temperature);

return builder.build();
}

@Override
public AnthropicClient generate(@NonNull String apiKey, Object... params) {
return newAnthropicClient(apiKey);
public AnthropicClient generate(@NonNull LlmProviderClientApiConfig config, Object... params) {
return newAnthropicClient(config);
}

@Override
public ChatLanguageModel generateChat(@NonNull String apiKey, @NonNull LlmAsJudgeModelParameters modelParameters) {
return newChatLanguageModel(apiKey, modelParameters);
public ChatLanguageModel generateChat(@NonNull LlmProviderClientApiConfig config,
@NonNull LlmAsJudgeModelParameters modelParameters) {
return newChatLanguageModel(config, modelParameters);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.comet.opik.api.LlmProvider;
import com.comet.opik.domain.llm.LlmProviderFactory;
import com.comet.opik.domain.llm.LlmProviderService;
import com.comet.opik.infrastructure.llm.LlmProviderClientApiConfig;
import com.comet.opik.infrastructure.llm.LlmServiceProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;
import lombok.NonNull;
Expand All @@ -21,14 +22,14 @@ class AnthropicLlmServiceProvider implements LlmServiceProvider {
}

@Override
public LlmProviderService getService(String apiKey) {
return new LlmProviderAnthropic(clientGenerator.generate(apiKey));
public LlmProviderService getService(LlmProviderClientApiConfig config) {
return new LlmProviderAnthropic(clientGenerator.generate(config));
}

@Override
public ChatLanguageModel getLanguageModel(@NonNull String apiKey,
public ChatLanguageModel getLanguageModel(@NonNull LlmProviderClientApiConfig config,
@NonNull LlmAsJudgeModelParameters modelParameters) {
return clientGenerator.generateChat(apiKey, modelParameters);
return clientGenerator.generateChat(config, modelParameters);
}

}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.comet.opik.infrastructure.llm.gemini;

import com.comet.opik.infrastructure.LlmProviderClientConfig;
import com.comet.opik.infrastructure.llm.LlmProviderClientApiConfig;
import com.comet.opik.infrastructure.llm.LlmProviderClientGenerator;
import com.google.common.base.Preconditions;
import dev.ai4j.openai4j.chat.ChatCompletionRequest;
Expand Down Expand Up @@ -34,18 +35,19 @@ public GoogleAiGeminiStreamingChatModel newGeminiStreamingClient(
}

@Override
public GoogleAiGeminiChatModel generate(String apiKey, Object... params) {
public GoogleAiGeminiChatModel generate(LlmProviderClientApiConfig config, Object... params) {
Preconditions.checkArgument(params.length >= 1, "Expected at least 1 parameter, got " + params.length);
ChatCompletionRequest request = (ChatCompletionRequest) Objects.requireNonNull(params[0],
"ChatCompletionRequest is required");
return newGeminiClient(apiKey, request);
return newGeminiClient(config.apiKey(), request);
}

@Override
public ChatLanguageModel generateChat(String apiKey, LlmAsJudgeModelParameters modelParameters) {
public ChatLanguageModel generateChat(LlmProviderClientApiConfig config,
LlmAsJudgeModelParameters modelParameters) {
GoogleAiGeminiChatModelBuilder modelBuilder = GoogleAiGeminiChatModel.builder()
.modelName(modelParameters.name())
.apiKey(apiKey);
.apiKey(config.apiKey());

Optional.ofNullable(llmProviderClientConfig.getConnectTimeout())
.ifPresent(connectTimeout -> modelBuilder.timeout(connectTimeout.toJavaDuration()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import com.comet.opik.api.LlmProvider;
import com.comet.opik.domain.llm.LlmProviderFactory;
import com.comet.opik.domain.llm.LlmProviderService;
import com.comet.opik.infrastructure.llm.LlmProviderClientApiConfig;
import com.comet.opik.infrastructure.llm.LlmServiceProvider;
import dev.langchain4j.model.chat.ChatLanguageModel;

Expand All @@ -17,13 +18,13 @@ public class GeminiLlmServiceProvider implements LlmServiceProvider {
}

@Override
public LlmProviderService getService(String apiKey) {
return new LlmProviderGemini(clientGenerator, apiKey);
public LlmProviderService getService(LlmProviderClientApiConfig config) {
return new LlmProviderGemini(clientGenerator, config);
}

@Override
public ChatLanguageModel getLanguageModel(String apiKey,
public ChatLanguageModel getLanguageModel(LlmProviderClientApiConfig config,
AutomationRuleEvaluatorLlmAsJudge.LlmAsJudgeModelParameters modelParameters) {
return clientGenerator.generateChat(apiKey, modelParameters);
return clientGenerator.generateChat(config, modelParameters);
}
}
Loading