diff --git a/src/main/java/org/chappiebot/ChappieService.java b/src/main/java/org/chappiebot/ChappieService.java index 68514a4..c9ae96c 100755 --- a/src/main/java/org/chappiebot/ChappieService.java +++ b/src/main/java/org/chappiebot/ChappieService.java @@ -6,11 +6,13 @@ import dev.langchain4j.mcp.client.transport.McpTransport; import dev.langchain4j.mcp.client.transport.http.StreamableHttpMcpTransport; import dev.langchain4j.mcp.client.transport.stdio.StdioMcpTransport; + import java.time.Duration; import java.util.Optional; import org.chappiebot.assist.Assistant; import org.chappiebot.exception.ExceptionAssistant; +import org.chappiebot.rag.RetrievalProvider; import org.eclipse.microprofile.config.inject.ConfigProperty; import dev.langchain4j.memory.chat.ChatMemoryProvider; import dev.langchain4j.memory.chat.MessageWindowChatMemory; @@ -30,15 +32,13 @@ import jakarta.inject.Inject; import dev.langchain4j.rag.RetrievalAugmentor; -import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever; -import dev.langchain4j.model.embedding.EmbeddingModel; -import dev.langchain4j.model.embedding.onnx.bgesmallenv15q.BgeSmallEnV15QuantizedEmbeddingModel; import dev.langchain4j.store.embedding.filter.comparison.ContainsString; -import dev.langchain4j.rag.DefaultRetrievalAugmentor; import dev.langchain4j.store.memory.chat.ChatMemoryStore; import jakarta.annotation.PreDestroy; + import java.util.List; import java.util.Map; + import org.chappiebot.rag.RagRequestContext; import org.chappiebot.store.StoreCreator; @@ -64,7 +64,7 @@ public class ChappieService { @ConfigProperty(name = "chappie.temperature", defaultValue = "0.2") double temperature; - + // OpenAI @ConfigProperty(name = "chappie.openai.api-key") @@ -86,43 +86,40 @@ public class ChappieService { // RAG - @ConfigProperty(name = "chappie.rag.enabled", defaultValue = "true") - boolean ragEnabled; - - @ConfigProperty(name = "chappie.rag.results.max", defaultValue = "4") - int ragMaxResults; - + @Inject + 
RetrievalProvider retrievalProvider; + // Store - + @ConfigProperty(name = "chappie.store.messages.max", defaultValue = "30") int maxMessages; - - + + @ConfigProperty(name = "quarkus.application.version") String appVersion; - + // MCP @ConfigProperty(name = "chappie.mcp.servers") Optional> mcpServers; - + @Inject StoreCreator storeCreator; @Inject ChatMemoryStore chatMemoryStore; - - @Inject + + @Inject RagRequestContext ragRequestContext; - + private RetrievalAugmentor retrievalAugmentor; private final List mcpClients = new java.util.concurrent.CopyOnWriteArrayList<>(); private McpToolProvider mcpToolProvider = null; - + private ChatRequestParameters chatRequestParameters = DefaultChatRequestParameters.builder() - .toolChoice(ToolChoice.AUTO) - .responseFormat(ResponseFormat.JSON) - .build(); - + .toolChoice(ToolChoice.AUTO) + .responseFormat(ResponseFormat.JSON) + .build(); + @PostConstruct public void init() { if (openaiKey.isPresent() || openaiBaseUrl.isPresent()) { @@ -138,10 +135,13 @@ public void init() { void shutdown() { // Be nice and close transports/clients for (McpClient c : mcpClients) { - try { c.close(); } catch (Exception ignored) {} + try { + c.close(); + } catch (Exception ignored) { + } } } - + private void loadOpenAiModel() { openaiBaseUrl.ifPresentOrElse( @@ -151,10 +151,9 @@ private void loadOpenAiModel() { Log.info("CHAPPiE timeout set to " + timeout); Log.info("CHAPPiE temperature set to " + temperature); - if(openaiKey.isEmpty())Log.warn("CHAPPiE is using the default 'demo' api key"); - - - + if (openaiKey.isEmpty()) Log.warn("CHAPPiE is using the default 'demo' api key"); + + OpenAiChatModel.OpenAiChatModelBuilder builder = OpenAiChatModel.builder() .logRequests(logRequest) .logResponses(logResponse) @@ -163,13 +162,13 @@ private void loadOpenAiModel() { .timeout(timeout) .temperature(temperature) .responseFormat("json_object"); - + if (!mcpServers.isEmpty() && !mcpServers.get().isEmpty()) { builder = builder - 
.defaultRequestParameters(chatRequestParameters) - .parallelToolCalls(false); + .defaultRequestParameters(chatRequestParameters) + .parallelToolCalls(false); } - + if (openaiBaseUrl.isPresent()) { builder = builder.baseUrl(openaiBaseUrl.get()); } @@ -181,7 +180,7 @@ private void loadOllamaModel() { Log.info("CHAPPiE is using Ollama " + ollamaModelName + "(" + ollamaBaseUrl + ")"); Log.info("CHAPPiE timeout set to " + timeout); Log.info("CHAPPiE temperature set to " + temperature); - + OllamaChatModel.OllamaChatModelBuilder builder = OllamaChatModel.builder() .logRequests(logRequest) .logResponses(logResponse) @@ -190,21 +189,21 @@ private void loadOllamaModel() { .timeout(timeout) .temperature(temperature) .responseFormat(ResponseFormat.JSON); - + if (!mcpServers.isEmpty() && !mcpServers.get().isEmpty()) { builder = builder - .defaultRequestParameters(chatRequestParameters); + .defaultRequestParameters(chatRequestParameters); } - + this.chatModel = builder.build(); } @Produces public ExceptionAssistant getExceptionAssistant() { - + AiServices assistantBuilder = AiServices.builder(ExceptionAssistant.class) .chatModel(chatModel); - + if (retrievalAugmentor != null) { assistantBuilder.retrievalAugmentor(retrievalAugmentor); } @@ -213,14 +212,14 @@ public ExceptionAssistant getExceptionAssistant() { } return assistantBuilder.build(); } - + @Produces public Assistant getAssistant() { - + AiServices assistantBuilder = AiServices.builder(Assistant.class) .chatModel(chatModel) .chatMemoryProvider(chatMemoryProvider()); - + if (retrievalAugmentor != null) { assistantBuilder.retrievalAugmentor(retrievalAugmentor); } @@ -235,44 +234,26 @@ private void enableRagIfPossible() { Log.info("CHAPPiE RAG not available; continuing without RAG."); return; } - + // TODO: This should use some local emmeding model if (openaiKey.isEmpty() && openaiBaseUrl.isEmpty()) { Log.warn("CHAPPiE RAG available but no OpenAI configuration for embeddings; continuing without RAG."); return; } - 
if(!ragEnabled) { - Log.warn("CHAPPiE RAG disabled by the user"); - return; - } - - EmbeddingModel embeddingModel = new BgeSmallEnV15QuantizedEmbeddingModel(); - - var retriever = EmbeddingStoreContentRetriever.builder() - .embeddingStore(storeCreator.getStore().get()) - .embeddingModel(embeddingModel) - .maxResults(ragMaxResults) - .dynamicFilter((t) -> { - Map variables = ragRequestContext.getVariables(); - if(variables!=null && !variables.isEmpty() && variables.containsKey("extension")){ - String extension = variables.get("extension"); - if(extension!=null && !extension.equalsIgnoreCase("any")){ - Log.info("Narrowing to [" + extension + "]"); - return new ContainsString("extensions_csv_padded", "," + extension + ","); - } - } - return null; - }) - .build(); - - this.retrievalAugmentor = DefaultRetrievalAugmentor.builder() - .contentRetriever(retriever) - .build(); - - Log.info("CHAPPiE RAG is enabled with " + ragMaxResults + " max results"); + this.retrievalAugmentor = retrievalProvider.getRetrievalAugmentor((t) -> { + Map variables = ragRequestContext.getVariables(); + if (variables != null && !variables.isEmpty() && variables.containsKey("extension")) { + String extension = variables.get("extension"); + if (extension != null && !extension.equalsIgnoreCase("any")) { + Log.info("Narrowing to [" + extension + "]"); + return new ContainsString("extensions_csv_padded", "," + extension + ","); + } + } + return null; + }); } - + private void enableMcpIfConfigured() { if (mcpServers.isEmpty() || mcpServers.get().isEmpty()) { Log.info("CHAPPiE MCP: no servers configured; continuing without MCP."); @@ -333,18 +314,17 @@ private void enableMcpIfConfigured() { Log.infof("CHAPPiE MCP: enabled with %d server(s).", clients.size()); } - - - + + private ChatMemoryProvider chatMemoryProvider() { Log.info("CHAPPiE Chat Memory is enabled with " + maxMessages + " max messages"); return memoryId -> MessageWindowChatMemory.builder() - .id(memoryId) - .maxMessages(maxMessages) - 
.chatMemoryStore(chatMemoryStore) - .build(); + .id(memoryId) + .maxMessages(maxMessages) + .chatMemoryStore(chatMemoryStore) + .build(); } - + private String versionOr(String fallback) { String v = appVersion; if (v != null && !v.isBlank()) return v; diff --git a/src/main/java/org/chappiebot/rag/RetrievalProvider.java b/src/main/java/org/chappiebot/rag/RetrievalProvider.java new file mode 100644 index 0000000..326886f --- /dev/null +++ b/src/main/java/org/chappiebot/rag/RetrievalProvider.java @@ -0,0 +1,116 @@ +package org.chappiebot.rag; + +import dev.langchain4j.data.embedding.Embedding; +import dev.langchain4j.data.segment.TextSegment; +import dev.langchain4j.model.embedding.EmbeddingModel; +import dev.langchain4j.model.embedding.onnx.bgesmallenv15q.BgeSmallEnV15QuantizedEmbeddingModel; +import dev.langchain4j.rag.DefaultRetrievalAugmentor; +import dev.langchain4j.rag.RetrievalAugmentor; +import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever; +import dev.langchain4j.rag.query.Query; +import dev.langchain4j.store.embedding.EmbeddingMatch; +import dev.langchain4j.store.embedding.EmbeddingSearchRequest; +import dev.langchain4j.store.embedding.EmbeddingSearchResult; +import dev.langchain4j.store.embedding.EmbeddingStore; +import dev.langchain4j.store.embedding.filter.Filter; +import dev.langchain4j.store.embedding.filter.comparison.ContainsString; +import io.quarkus.logging.Log; +import jakarta.annotation.PostConstruct; +import jakarta.enterprise.context.ApplicationScoped; +import jakarta.inject.Inject; +import org.chappiebot.search.SearchMatch; +import org.chappiebot.store.StoreCreator; +import org.eclipse.microprofile.config.inject.ConfigProperty; + +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; + +@ApplicationScoped +public class RetrievalProvider { + + @Inject + StoreCreator storeCreator; + + EmbeddingModel embeddingModel; + + private EmbeddingStore embeddingStore; + 
+ @ConfigProperty(name = "chappie.rag.enabled", defaultValue = "true") + boolean ragEnabled; + + @ConfigProperty(name = "chappie.rag.results.max", defaultValue = "4") + int ragMaxResults; + + @PostConstruct + public void init() { + if (ragEnabled) { + loadEmbeddingModel(); + loadVectorStore(); + } + } + + public int getRagMaxResults() { + return ragMaxResults; + } + + private void loadVectorStore() { + embeddingStore = storeCreator.getStore().get(); + } + + private void loadEmbeddingModel() { + embeddingModel = new BgeSmallEnV15QuantizedEmbeddingModel(); + } + + // Wrap the name in commas so only whole CSV entries match, like the chat filter in ChappieService (assumes the stored value is comma-padded, per the field name - confirm). + private Filter extensionFilter(String extension) { + return new ContainsString("extensions_csv_padded", "," + extension + ","); + } + + public List search(String queryMessage, int maxResults, String restrictToExtension) { + Embedding embeddedQuery = embeddingModel.embed(queryMessage).content(); + + EmbeddingSearchRequest.EmbeddingSearchRequestBuilder requestBuilder = EmbeddingSearchRequest.builder() + .queryEmbedding(embeddedQuery) + .maxResults(maxResults) + .minScore(0.0); + if (restrictToExtension != null && !restrictToExtension.isEmpty()) { // an empty extension means "no restriction" + Log.info("Restricting search to extension: " + restrictToExtension); + requestBuilder.filter(extensionFilter(restrictToExtension)); + } + EmbeddingSearchRequest searchRequest = requestBuilder.build(); + + EmbeddingSearchResult searchResult = embeddingStore.search(searchRequest); + + return searchResult.matches().stream() + .map(RetrievalProvider::extractContent) + .collect(Collectors.toList()); + } + + private static SearchMatch extractContent(EmbeddingMatch embeddingMatch) { + Map metadata = embeddingMatch.embedded().metadata().toMap(); + // Remove the actual embedding vector from metadata to reduce payload size + metadata.remove("embedding"); + return new SearchMatch(embeddingMatch.embedded().text(), embeddingMatch.embeddingId(), embeddingMatch.score(), + metadata); + } + + public RetrievalAugmentor
getRetrievalAugmentor(Function filterFunction) { + if (ragEnabled && embeddingModel != null) { + Log.info("CHAPPiE RAG is enabled with " + ragMaxResults + " max results"); + + var retriever = EmbeddingStoreContentRetriever.builder() + .embeddingStore(embeddingStore) + .embeddingModel(embeddingModel) + .maxResults(ragMaxResults) + .dynamicFilter(filterFunction) + .build(); + + return DefaultRetrievalAugmentor.builder() + .contentRetriever(retriever) + .build(); + } + return null; + } +} diff --git a/src/main/java/org/chappiebot/search/SearchEndpoint.java b/src/main/java/org/chappiebot/search/SearchEndpoint.java new file mode 100644 index 0000000..d62ba85 --- /dev/null +++ b/src/main/java/org/chappiebot/search/SearchEndpoint.java @@ -0,0 +1,30 @@ +package org.chappiebot.search; + +import io.quarkus.logging.Log; +import jakarta.inject.Inject; +import jakarta.ws.rs.POST; +import jakarta.ws.rs.Path; +import org.chappiebot.rag.RetrievalProvider; + +import java.util.List; +import java.util.Objects; + + +@Path("/api/search") +public class SearchEndpoint { + + @Inject + RetrievalProvider retrievalProvider; + + @POST + public SearchResponse search(SearchRequest query) { + Log.info("Search request: " + query.queryMessage()); + String queryMessage = query.queryMessage(); + int maxResults = Objects.requireNonNullElse(query.maxResults(), retrievalProvider.getRagMaxResults()); + String restrictToExtension = query.extension(); + + List search = retrievalProvider.search(queryMessage, maxResults, restrictToExtension); + return new SearchResponse(search); + } + +} diff --git a/src/main/java/org/chappiebot/search/SearchMatch.java b/src/main/java/org/chappiebot/search/SearchMatch.java new file mode 100644 index 0000000..c0e8d47 --- /dev/null +++ b/src/main/java/org/chappiebot/search/SearchMatch.java @@ -0,0 +1,6 @@ +package org.chappiebot.search; + +import java.util.Map; + +public record SearchMatch(String text, String source, double score, Map metadata) { +} diff --git 
a/src/main/java/org/chappiebot/search/SearchRequest.java b/src/main/java/org/chappiebot/search/SearchRequest.java new file mode 100644 index 0000000..38d9c1a --- /dev/null +++ b/src/main/java/org/chappiebot/search/SearchRequest.java @@ -0,0 +1,4 @@ +package org.chappiebot.search; + +public record SearchRequest(String queryMessage, Integer maxResults, String extension) { +} diff --git a/src/main/java/org/chappiebot/search/SearchResponse.java b/src/main/java/org/chappiebot/search/SearchResponse.java new file mode 100644 index 0000000..95b5d1b --- /dev/null +++ b/src/main/java/org/chappiebot/search/SearchResponse.java @@ -0,0 +1,6 @@ +package org.chappiebot.search; + +import java.util.List; + +public record SearchResponse(List results) { +} diff --git a/src/main/resources/application-dev.properties b/src/main/resources/application-dev.properties new file mode 100755 index 0000000..b4575aa --- /dev/null +++ b/src/main/resources/application-dev.properties @@ -0,0 +1,2 @@ +# Use the latest ingestion PostgreSQL dev service image +quarkus.datasource.devservices.image-name=ghcr.io/quarkusio/chappie-ingestion-quarkus:latest