Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion lmos-runtime-bom/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ dependencies {
val springBootVersion = "3.5.6"
val ktorVersion = "3.3.0"
val kotlinxVersion = "1.9.0"
val lmosRouterVersion = "0.21.0"
val lmosRouterVersion = "0.22.0-M1"
val arcVersion = "0.174.0"
val langChain4jVersion = "1.5.0"
val kotlinCoroutines = "1.10.2"
Expand All @@ -33,6 +33,7 @@ dependencies {
api("org.eclipse.lmos:lmos-classifier-vector-spring-boot-starter:$lmosRouterVersion")
api("org.eclipse.lmos:lmos-classifier-hybrid-spring-boot-starter:$lmosRouterVersion")
api("org.eclipse.lmos:lmos-classifier-core:$lmosRouterVersion")
api("org.eclipse.lmos:lmos-classifier-llm:$lmosRouterVersion")
api("org.eclipse.lmos:lmos-router-llm:$lmosRouterVersion")
// arcVersion-managed
api("org.eclipse.lmos:arc-agent-client:$arcVersion")
Expand Down
1 change: 1 addition & 0 deletions lmos-runtime-core/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies {
api("org.eclipse.lmos:arc-agent-client")
api("org.eclipse.lmos:arc-api")
api("org.eclipse.lmos:lmos-classifier-core")
api("org.eclipse.lmos:lmos-classifier-llm")

implementation("io.ktor:ktor-client-cio")
implementation("io.ktor:ktor-serialization-kotlinx-json")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,21 @@
package org.eclipse.lmos.runtime.core.disambiguation

import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper
import com.fasterxml.jackson.module.kotlin.readValue
import dev.langchain4j.data.message.AiMessage
import dev.langchain4j.data.message.ChatMessage
import dev.langchain4j.data.message.SystemMessage
import dev.langchain4j.data.message.UserMessage
import dev.langchain4j.model.chat.ChatModel
import dev.langchain4j.model.chat.request.ChatRequest
import dev.langchain4j.model.chat.request.ResponseFormat
import dev.langchain4j.model.chat.request.ResponseFormatType
import dev.langchain4j.model.chat.response.ChatResponse
import dev.langchain4j.service.output.JsonSchemas
import org.eclipse.lmos.classifier.core.Agent
import org.eclipse.lmos.classifier.core.tracing.ClassifierTracer
import org.eclipse.lmos.classifier.core.tracing.NoopClassifierTracer
import org.eclipse.lmos.classifier.llm.OpenInferenceTags
import org.eclipse.lmos.runtime.core.model.Conversation
import org.slf4j.LoggerFactory

Expand All @@ -31,7 +39,7 @@ interface DisambiguationHandler {
* @param candidateAgents agents with the highest match scores
* @return the disambiguation result
*/
fun disambiguate(
suspend fun disambiguate(
conversation: Conversation,
candidateAgents: List<Agent>,
): DisambiguationResult
Expand All @@ -41,28 +49,49 @@ class DefaultDisambiguationHandler(
private val chatModel: ChatModel,
private val introductionPrompt: String,
private val clarificationPrompt: String,
private val tracer: ClassifierTracer = NoopClassifierTracer(),
) : DisambiguationHandler {
private val logger = LoggerFactory.getLogger(javaClass)
private val jacksonObjectMapper = jacksonObjectMapper()
private val responseFormat =
ResponseFormat
.builder()
.type(ResponseFormatType.JSON)
.jsonSchema(JsonSchemas.jsonSchemaFrom(DisambiguationResult::class.java).get())
.build()

override fun disambiguate(
override suspend fun disambiguate(
conversation: Conversation,
candidateAgents: List<Agent>,
): DisambiguationResult {
): DisambiguationResult =
tracer.withSpan("llm") { tags ->
val chatRequest = prepareChatRequest(conversation, candidateAgents)
val chatResponse = chatModel.chat(chatRequest)
val disambiguationResult = prepareDisambiguationResult(chatResponse)
OpenInferenceTags.applyModelTracingTags(tags, chatRequest, chatResponse)
logger
.atDebug()
.addKeyValue("result", disambiguationResult)
.addKeyValue("event", "DISAMBIGUATION_DONE")
.log("Executed disambiguation.")

disambiguationResult
}

private fun prepareChatRequest(
conversation: Conversation,
candidateAgents: List<Agent>,
): ChatRequest {
val disambiguationMessages = mutableListOf<ChatMessage>()
disambiguationMessages.add(prepareIntroductionSystemMessage())
disambiguationMessages.addAll(prepareChatMessages(conversation))
disambiguationMessages.add(prepareClarificationSystemMessage(candidateAgents))

val chatResponse = chatModel.chat(disambiguationMessages)
val disambiguationResult = prepareDisambiguationResult(chatResponse)
logger
.atDebug()
.addKeyValue("result", disambiguationResult)
.addKeyValue("event", "DISAMBIGUATION_DONE")
.log("Executed disambiguation.")

return disambiguationResult
return ChatRequest
.builder()
.responseFormat(responseFormat)
.messages(disambiguationMessages)
.build()
}

private fun prepareIntroductionSystemMessage() = SystemMessage(introductionPrompt)
Expand Down Expand Up @@ -96,7 +125,7 @@ class DefaultDisambiguationHandler(
?: throw IllegalStateException("Disambiguation response is empty or null.")

return try {
jacksonObjectMapper.readValue(json, DisambiguationResult::class.java)
jacksonObjectMapper.readValue<DisambiguationResult>(json)
} catch (ex: Exception) {
logger.error("Failed to parse disambiguation result, JSON: $json", ex)
throw IllegalArgumentException("Invalid disambiguation result format.", ex)
Expand All @@ -105,9 +134,9 @@ class DefaultDisambiguationHandler(
}

data class DisambiguationResult(
val topics: List<String>,
val topics: List<String>?,
val reasoning: String,
val onlyConfirmation: Boolean,
val confidence: Int,
val confidence: Int?,
val clarificationQuestion: String,
)
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
package org.eclipse.lmos.runtime.core.disambiguation

import dev.langchain4j.data.message.AiMessage
import dev.langchain4j.data.message.ChatMessage
import dev.langchain4j.data.message.SystemMessage
import dev.langchain4j.data.message.UserMessage
import dev.langchain4j.model.chat.ChatModel
import dev.langchain4j.model.chat.request.ChatRequest
import dev.langchain4j.model.chat.response.ChatResponse
import dev.langchain4j.model.output.TokenUsage
import io.mockk.every
import io.mockk.mockk
import io.mockk.slot
import kotlinx.coroutines.runBlocking
import org.eclipse.lmos.arc.api.Message
import org.eclipse.lmos.classifier.core.Agent
import org.eclipse.lmos.classifier.core.Capability
Expand All @@ -24,6 +26,7 @@ import org.eclipse.lmos.runtime.core.model.SystemContext
import org.eclipse.lmos.runtime.core.model.UserContext
import org.junit.jupiter.api.Assertions.*
import org.junit.jupiter.api.Test
import kotlin.test.assertFailsWith

class DefaultDisambiguationHandlerTest {
private val chatModel = mockk<ChatModel>()
Expand All @@ -38,75 +41,78 @@ class DefaultDisambiguationHandlerTest {
)

@Test
fun `disambiguate prepares chat model messages and returns clarification question correctly`() {
// given
val userMessage = "Hello, I need help with my contract"
val conversation = conversation(userMessage)
val candidateAgents = candidateAgents()
val chatModelResponse = chatResponse(disambiguationJsonResponse())
val messagesSlot = slot<List<ChatMessage>>()
every { chatModel.chat(capture(messagesSlot)) } returns chatModelResponse

// when
val disambiguationResult = underTest.disambiguate(conversation, candidateAgents)

// then ...
// chat model messages were prepared correctly
val messages = messagesSlot.captured
assertEquals(3, messages.size)

assertEquals(messages[0].javaClass, SystemMessage::class.java)
assertEquals(introductionPrompt, (messages[0] as SystemMessage).text())

assertEquals(messages[1].javaClass, UserMessage::class.java)
assertEquals(userMessage, (messages[1] as UserMessage).singleText())

assertEquals(messages[2].javaClass, SystemMessage::class.java)
assertEquals(
"""
Clarify: Topic 'contract-agent-id':
- View contract details
- Cancel a contract
""".trimIndent(),
(messages[2] as SystemMessage).text(),
)
// and clarification question is returned
assertNotNull(disambiguationResult)
assertEquals("Which contract would you like to view?", disambiguationResult.clarificationQuestion)
}
fun `disambiguate prepares chat model messages and returns clarification question correctly`(): Unit =
runBlocking {
// given
val userMessage = "Hello, I need help with my contract"
val conversation = conversation(userMessage)
val candidateAgents = candidateAgents()
val chatModelResponse = chatResponse(disambiguationJsonResponse())
val messagesSlot = slot<ChatRequest>()
every { chatModel.chat(capture(messagesSlot)) } returns chatModelResponse

// when
val disambiguationResult = underTest.disambiguate(conversation, candidateAgents)

// then ...
// chat model messages were prepared correctly
val messages = messagesSlot.captured.messages()
assertEquals(3, messages.size)

assertEquals(messages[0].javaClass, SystemMessage::class.java)
assertEquals(introductionPrompt, (messages[0] as SystemMessage).text())

assertEquals(messages[1].javaClass, UserMessage::class.java)
assertEquals(userMessage, (messages[1] as UserMessage).singleText())

assertEquals(messages[2].javaClass, SystemMessage::class.java)
assertEquals(
"""
Clarify: Topic 'contract-agent-id':
- View contract details
- Cancel a contract
""".trimIndent(),
(messages[2] as SystemMessage).text(),
)
// and clarification question is returned
assertNotNull(disambiguationResult)
assertEquals("Which contract would you like to view?", disambiguationResult.clarificationQuestion)
}

@Test
fun `disambiguate throws IllegalStateException when response is null`() {
// given
val conversation = conversation("Whats up?")
val agents = candidateAgents()
val chatResponse = chatResponse(null)
every { chatModel.chat(any<List<ChatMessage>>()) } returns chatResponse

// when
val exception =
assertThrows(IllegalStateException::class.java) {
underTest.disambiguate(conversation, agents)
}

// then
assertEquals("Disambiguation response is empty or null.", exception.message)
}
fun `disambiguate throws IllegalStateException when response is null`(): Unit =
runBlocking {
// given
val conversation = conversation("Whats up?")
val agents = candidateAgents()
val chatResponse = chatResponse(null)
every { chatModel.chat(any<ChatRequest>()) } returns chatResponse

// when
val exception =
assertFailsWith<IllegalStateException> {
underTest.disambiguate(conversation, agents)
}

// then
assertEquals("Disambiguation response is empty or null.", exception.message)
}

@Test
fun `disambiguate throws IllegalArgumentException when JSON response is invalid`() {
val conversation = conversation("Whats up?")
val agents = candidateAgents()
val chatResponse = chatResponse("invalid json")
every { chatModel.chat(any<List<ChatMessage>>()) } returns chatResponse

val exception =
assertThrows(IllegalArgumentException::class.java) {
underTest.disambiguate(conversation, agents)
}

assertTrue(exception.message!!.contains("Invalid disambiguation result format."))
}
fun `disambiguate throws IllegalArgumentException when JSON response is invalid`(): Unit =
runBlocking {
val conversation = conversation("Whats up?")
val agents = candidateAgents()
val chatResponse = chatResponse("invalid json")
every { chatModel.chat(any<ChatRequest>()) } returns chatResponse

val exception =
assertFailsWith<IllegalArgumentException> {
underTest.disambiguate(conversation, agents)
}

assertTrue(exception.message!!.contains("Invalid disambiguation result format."))
}

private fun candidateAgents() =
listOf(
Expand Down Expand Up @@ -135,6 +141,8 @@ class DefaultDisambiguationHandlerTest {
private fun chatResponse(text: String?): ChatResponse? =
ChatResponse
.builder()
.modelName("MyModel")
.tokenUsage(TokenUsage(1, 2, 3))
.aiMessage(AiMessage(text, emptyList()))
.build()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

package org.eclipse.lmos.runtime.core.outbound

import io.mockk.every
import io.mockk.coEvery
import io.mockk.mockk
import kotlinx.coroutines.runBlocking
import org.eclipse.lmos.arc.api.Message
Expand Down Expand Up @@ -44,7 +44,7 @@ class LmosAgentClassifierServiceTest {
val conversation = Conversation(defaultInputContext, defaultSystemContext, defaultUserContext)

var capturedRequest: ClassificationRequest? = null
every { classifierMock.classify(any()) } answers {
coEvery { classifierMock.classify(any()) } answers {
capturedRequest = firstArg()
ClassificationResult(emptyList())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ package org.eclipse.lmos.runtime.config

import dev.langchain4j.model.chat.ChatModel
import org.eclipse.lmos.classifier.core.AgentClassifier
import org.eclipse.lmos.classifier.core.tracing.ClassifierTracer
import org.eclipse.lmos.classifier.core.tracing.NoopClassifierTracer
import org.eclipse.lmos.classifier.llm.ChatModelClientProperties
import org.eclipse.lmos.classifier.llm.LangChainChatModelFactory
import org.eclipse.lmos.runtime.channelrouting.DefaultCachedChannelRoutingRepository
Expand Down Expand Up @@ -136,11 +138,13 @@ class RuntimeAutoConfiguration(
fun disambiguationHandler(
@Qualifier("disambiguationChatModel") chatModel: ChatModel,
lmosRuntimeProperties: RuntimeProperties,
tracerProvider: ObjectProvider<ClassifierTracer>,
): DisambiguationHandler =
DefaultDisambiguationHandler(
chatModel,
lmosRuntimeProperties.disambiguation.introductionPrompt(),
lmosRuntimeProperties.disambiguation.clarificationPrompt(),
tracerProvider.getIfAvailable { NoopClassifierTracer() },
)

@Bean
Expand Down