Skip to content

Update to Kotlin 2.1.20 and minor refactoring #78

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ plugins {
alias(libs.plugins.kotlin.serialization)
alias(libs.plugins.dokka)
alias(libs.plugins.jreleaser)
alias(libs.plugins.atomicfu)
`maven-publish`
alias(libs.plugins.kotlinx.binary.compatibility.validator)
}
Expand Down
4 changes: 1 addition & 3 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
[versions]
# plugins version
kotlin = "2.0.21"
kotlin = "2.1.20"
dokka = "2.0.0"
atomicfu = "0.26.1"

# libraries version
serialization = "1.7.3"
Expand Down Expand Up @@ -40,5 +39,4 @@ kotlin-multiplatform = { id = "org.jetbrains.kotlin.multiplatform", version.ref
kotlin-serialization = { id = "org.jetbrains.kotlin.plugin.serialization", version.ref = "kotlin" }
dokka = { id = "org.jetbrains.dokka", version.ref = "dokka" }
jreleaser = { id = "org.jreleaser", version.ref = "jreleaser"}
atomicfu = { id = "org.jetbrains.kotlinx.atomicfu", version.ref = "atomicfu" }
kotlinx-binary-compatibility-validator = { id = "org.jetbrains.kotlinx.binary-compatibility-validator", version.ref = "binaryCompatibilityValidatorPlugin" }
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ import io.ktor.http.*
import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage
import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport
import io.modelcontextprotocol.kotlin.sdk.shared.McpJson
import kotlinx.atomicfu.AtomicBoolean
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.*
import kotlinx.serialization.encodeToString
import kotlin.concurrent.atomics.AtomicBoolean
import kotlin.concurrent.atomics.ExperimentalAtomicApi
import kotlin.properties.Delegates
import kotlin.time.Duration

Expand All @@ -22,6 +22,7 @@ public typealias SSEClientTransport = SseClientTransport
* Client transport for SSE: this will connect to a server using Server-Sent Events for receiving
* messages and make separate POST requests for sending messages.
*/
@OptIn(ExperimentalAtomicApi::class)
public class SseClientTransport(
private val client: HttpClient,
private val urlString: String?,
Expand All @@ -32,7 +33,7 @@ public class SseClientTransport(
CoroutineScope(session.coroutineContext + SupervisorJob())
}

private val initialized: AtomicBoolean = atomic(false)
private val initialized: AtomicBoolean = AtomicBoolean(false)
private var session: ClientSSESession by Delegates.notNull()
private val endpoint = CompletableDeferred<String>()

Expand Down Expand Up @@ -127,7 +128,7 @@ public class SseClientTransport(
}

override suspend fun close() {
if (!initialized.value) {
if (!initialized.load()) {
error("SSEClientTransport is not initialized!")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,12 @@ import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage
import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport
import io.modelcontextprotocol.kotlin.sdk.shared.ReadBuffer
import io.modelcontextprotocol.kotlin.sdk.shared.serializeMessage
import kotlinx.atomicfu.AtomicBoolean
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.consumeEach
import kotlinx.io.Buffer
import kotlinx.io.Sink
import kotlinx.io.Source
import kotlinx.io.buffered
import kotlinx.io.readByteArray
import kotlinx.io.writeString
import kotlinx.io.*
import kotlin.concurrent.atomics.AtomicBoolean
import kotlin.concurrent.atomics.ExperimentalAtomicApi
import kotlin.coroutines.CoroutineContext

/**
Expand All @@ -27,6 +22,7 @@ import kotlin.coroutines.CoroutineContext
* @param input The input stream where messages are received.
* @param output The output stream where messages are sent.
*/
@OptIn(ExperimentalAtomicApi::class)
public class StdioClientTransport(
private val input: Source,
private val output: Sink
Expand All @@ -37,7 +33,7 @@ public class StdioClientTransport(
CoroutineScope(ioCoroutineContext + SupervisorJob())
}
private var job: Job? = null
private val initialized: AtomicBoolean = atomic(false)
private val initialized: AtomicBoolean = AtomicBoolean(false)
private val sendChannel = Channel<JSONRPCMessage>(Channel.UNLIMITED)
private val readBuffer = ReadBuffer()

Expand Down Expand Up @@ -96,7 +92,7 @@ public class StdioClientTransport(
}

override suspend fun send(message: JSONRPCMessage) {
if (!initialized.value) {
if (!initialized.load()) {
error("Transport not started")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ import io.ktor.server.sse.*
import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage
import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport
import io.modelcontextprotocol.kotlin.sdk.shared.McpJson
import kotlinx.atomicfu.AtomicBoolean
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.job
import kotlinx.serialization.encodeToString
import kotlin.concurrent.atomics.AtomicBoolean
import kotlin.concurrent.atomics.ExperimentalAtomicApi
import kotlin.uuid.ExperimentalUuidApi
import kotlin.uuid.Uuid

Expand All @@ -25,11 +25,12 @@ public typealias SSEServerTransport = SseServerTransport
*
* Creates a new SSE server transport, which will direct the client to POST messages to the relative or absolute URL identified by `_endpoint`.
*/
@OptIn(ExperimentalAtomicApi::class)
public class SseServerTransport(
private val endpoint: String,
private val session: ServerSSESession,
) : AbstractTransport() {
private val initialized: AtomicBoolean = atomic(false)
private val initialized: AtomicBoolean = AtomicBoolean(false)

@OptIn(ExperimentalUuidApi::class)
public val sessionId: String = Uuid.random().toString()
Expand Down Expand Up @@ -63,7 +64,7 @@ public class SseServerTransport(
* This should be called when a POST request is made to send a message to the server.
*/
public suspend fun handlePostMessage(call: ApplicationCall) {
if (!initialized.value) {
if (!initialized.load()) {
val message = "SSE connection not established"
call.respondText(message, status = HttpStatusCode.InternalServerError)
_onError.invoke(IllegalStateException(message))
Expand Down Expand Up @@ -112,7 +113,7 @@ public class SseServerTransport(
}

override suspend fun send(message: JSONRPCMessage) {
if (!initialized.value) {
if (!initialized.load()) {
throw error("Not connected")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,43 +5,38 @@ import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage
import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport
import io.modelcontextprotocol.kotlin.sdk.shared.ReadBuffer
import io.modelcontextprotocol.kotlin.sdk.shared.serializeMessage
import kotlinx.atomicfu.AtomicBoolean
import kotlinx.atomicfu.atomic
import kotlinx.atomicfu.locks.ReentrantLock
import kotlinx.atomicfu.locks.withLock
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import kotlinx.io.Buffer
import kotlinx.io.Sink
import kotlinx.io.Source
import kotlinx.io.buffered
import kotlinx.io.readByteArray
import kotlinx.io.writeString
import kotlinx.io.*
import kotlin.concurrent.atomics.AtomicBoolean
import kotlin.concurrent.atomics.ExperimentalAtomicApi
import kotlin.coroutines.CoroutineContext

/**
* A server transport that communicates with a client via standard I/O.
*
* Reads from System.in and writes to System.out.
*/
@OptIn(ExperimentalAtomicApi::class)
public class StdioServerTransport(
private val inputStream: Source, //BufferedInputStream = BufferedInputStream(System.`in`),
outputStream: Sink //PrintStream = System.out
private val inputStream: Source,
outputStream: Sink
) : AbstractTransport() {
private val logger = KotlinLogging.logger {}

private val readBuffer = ReadBuffer()
private val initialized: AtomicBoolean = atomic(false)
private val initialized: AtomicBoolean = AtomicBoolean(false)
private var readingJob: Job? = null
private var sendingJob: Job? = null

private val coroutineContext: CoroutineContext = Dispatchers.IO + SupervisorJob()
private val scope = CoroutineScope(coroutineContext)
private val readChannel = Channel<ByteArray>(Channel.UNLIMITED)
private val writeChannel = Channel<JSONRPCMessage>(Channel.UNLIMITED)
private val outputWriter = outputStream.buffered()
private val lock = ReentrantLock()

override suspend fun start() {
if (!initialized.compareAndSet(false, true)) {
if (!initialized.compareAndSet(expectedValue = false, newValue = true)) {
error("StdioServerTransport already started!")
}

Expand Down Expand Up @@ -80,6 +75,20 @@ public class StdioServerTransport(
_onError.invoke(e)
}
}

// Launch a coroutine to handle message sending
sendingJob = scope.launch {
try {
for (message in writeChannel) {
val json = serializeMessage(message)
outputWriter.writeString(json)
outputWriter.flush()
}
} catch (e: Throwable) {
logger.error(e) { "Error writing to stdout" }
_onError.invoke(e)
}
}
}

private suspend fun processReadBuffer() {
Expand All @@ -102,22 +111,20 @@ public class StdioServerTransport(
}

override suspend fun close() {
if (!initialized.compareAndSet(true, false)) return
if (!initialized.compareAndSet(expectedValue = true, newValue = false)) return

// Cancel reading job and close channel
readingJob?.cancel() // ToDO("was cancel and join")
sendingJob?.cancel()

readChannel.close()
writeChannel.close()
readBuffer.clear()

_onClose.invoke()
}

override suspend fun send(message: JSONRPCMessage) {
val json = serializeMessage(message)
lock.withLock {
// You may need to add Content-Length headers before the message if using the LSP framing protocol
outputWriter.writeString(json)
outputWriter.flush()
}
writeChannel.send(message)
}
}
Original file line number Diff line number Diff line change
@@ -1,33 +1,26 @@
package io.modelcontextprotocol.kotlin.sdk.shared

import io.ktor.websocket.Frame
import io.ktor.websocket.WebSocketSession
import io.ktor.websocket.close
import io.ktor.websocket.readText
import io.ktor.websocket.*
import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage
import kotlinx.atomicfu.AtomicBoolean
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.InternalCoroutinesApi
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.ClosedReceiveChannelException
import kotlinx.coroutines.job
import kotlinx.coroutines.launch
import kotlinx.serialization.encodeToString
import kotlin.concurrent.atomics.AtomicBoolean
import kotlin.concurrent.atomics.ExperimentalAtomicApi

internal const val MCP_SUBPROTOCOL = "mcp"

/**
* Abstract class representing a WebSocket transport for the Model Context Protocol (MCP).
* Handles communication over a WebSocket session.
*/
@OptIn(ExperimentalAtomicApi::class)
public abstract class WebSocketMcpTransport : AbstractTransport() {
private val scope by lazy {
CoroutineScope(session.coroutineContext + SupervisorJob())
}

private val initialized: AtomicBoolean = atomic(false)
private val initialized: AtomicBoolean = AtomicBoolean(false)
/**
* The WebSocket session used for communication.
*/
Expand Down Expand Up @@ -83,15 +76,15 @@ public abstract class WebSocketMcpTransport : AbstractTransport() {
}

override suspend fun send(message: JSONRPCMessage) {
if (!initialized.value) {
if (!initialized.load()) {
error("Not connected")
}

session.outgoing.send(Frame.Text(McpJson.encodeToString(message)))
}

override suspend fun close() {
if (!initialized.value) {
if (!initialized.load()) {
error("Not connected")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
package io.modelcontextprotocol.kotlin.sdk

import io.modelcontextprotocol.kotlin.sdk.shared.McpJson
import kotlinx.atomicfu.AtomicLong
import kotlinx.atomicfu.atomic
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.decodeFromJsonElement
import kotlinx.serialization.json.encodeToJsonElement
import kotlin.jvm.JvmInline
import kotlin.concurrent.atomics.AtomicLong
import kotlin.concurrent.atomics.ExperimentalAtomicApi
import kotlin.concurrent.atomics.incrementAndFetch

public const val LATEST_PROTOCOL_VERSION: String = "2024-11-05"

Expand All @@ -21,7 +21,8 @@ public val SUPPORTED_PROTOCOL_VERSIONS: Array<String> = arrayOf(

public const val JSONRPC_VERSION: String = "2.0"

private val REQUEST_MESSAGE_ID: AtomicLong = atomic(0L)
@OptIn(ExperimentalAtomicApi::class)
private val REQUEST_MESSAGE_ID: AtomicLong = AtomicLong(0L)

/**
* A progress token, used to associate progress notifications with the original request.
Expand Down Expand Up @@ -132,7 +133,7 @@ internal fun Request.toJSON(): JSONRPCRequest {
*/
internal fun JSONRPCRequest.fromJSON(): Request? {
val serializer = selectRequestDeserializer(method)
val params = params ?: return null
val params = params
return McpJson.decodeFromJsonElement<Request>(serializer, params)
}

Expand Down Expand Up @@ -211,9 +212,10 @@ public sealed interface JSONRPCMessage
/**
* A request that expects a response.
*/
@OptIn(ExperimentalAtomicApi::class)
@Serializable
public data class JSONRPCRequest(
val id: RequestId = RequestId.NumberId(REQUEST_MESSAGE_ID.incrementAndGet()),
val id: RequestId = RequestId.NumberId(REQUEST_MESSAGE_ID.incrementAndFetch()),
val method: String,
val params: JsonElement = EmptyJsonObject,
val jsonrpc: String = JSONRPC_VERSION,
Expand Down
Loading