@@ -5,43 +5,38 @@ import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage
5
5
import io.modelcontextprotocol.kotlin.sdk.shared.AbstractTransport
6
6
import io.modelcontextprotocol.kotlin.sdk.shared.ReadBuffer
7
7
import io.modelcontextprotocol.kotlin.sdk.shared.serializeMessage
8
- import kotlinx.atomicfu.AtomicBoolean
9
- import kotlinx.atomicfu.atomic
10
- import kotlinx.atomicfu.locks.ReentrantLock
11
- import kotlinx.atomicfu.locks.withLock
12
8
import kotlinx.coroutines.*
13
9
import kotlinx.coroutines.channels.Channel
14
- import kotlinx.io.Buffer
15
- import kotlinx.io.Sink
16
- import kotlinx.io.Source
17
- import kotlinx.io.buffered
18
- import kotlinx.io.readByteArray
19
- import kotlinx.io.writeString
10
+ import kotlinx.io.*
11
+ import kotlin.concurrent.atomics.AtomicBoolean
12
+ import kotlin.concurrent.atomics.ExperimentalAtomicApi
20
13
import kotlin.coroutines.CoroutineContext
21
14
22
15
/* *
23
16
* A server transport that communicates with a client via standard I/O.
24
17
*
25
18
* Reads from System.in and writes to System.out.
26
19
*/
20
+ @OptIn(ExperimentalAtomicApi ::class )
27
21
public class StdioServerTransport (
28
- private val inputStream : Source , // BufferedInputStream = BufferedInputStream(System.`in`),
29
- outputStream : Sink // PrintStream = System.out
22
+ private val inputStream : Source ,
23
+ outputStream : Sink
30
24
) : AbstractTransport() {
31
25
private val logger = KotlinLogging .logger {}
32
26
33
27
private val readBuffer = ReadBuffer ()
34
- private val initialized: AtomicBoolean = atomic (false )
28
+ private val initialized: AtomicBoolean = AtomicBoolean (false )
35
29
private var readingJob: Job ? = null
30
+ private var sendingJob: Job ? = null
36
31
37
32
private val coroutineContext: CoroutineContext = Dispatchers .IO + SupervisorJob ()
38
33
private val scope = CoroutineScope (coroutineContext)
39
34
private val readChannel = Channel <ByteArray >(Channel .UNLIMITED )
35
+ private val writeChannel = Channel <JSONRPCMessage >(Channel .UNLIMITED )
40
36
private val outputWriter = outputStream.buffered()
41
- private val lock = ReentrantLock ()
42
37
43
38
override suspend fun start () {
44
- if (! initialized.compareAndSet(false , true )) {
39
+ if (! initialized.compareAndSet(expectedValue = false , newValue = true )) {
45
40
error(" StdioServerTransport already started!" )
46
41
}
47
42
@@ -80,6 +75,20 @@ public class StdioServerTransport(
80
75
_onError .invoke(e)
81
76
}
82
77
}
78
+
79
+ // Launch a coroutine to handle message sending
80
+ sendingJob = scope.launch {
81
+ try {
82
+ for (message in writeChannel) {
83
+ val json = serializeMessage(message)
84
+ outputWriter.writeString(json)
85
+ outputWriter.flush()
86
+ }
87
+ } catch (e: Throwable ) {
88
+ logger.error(e) { " Error writing to stdout" }
89
+ _onError .invoke(e)
90
+ }
91
+ }
83
92
}
84
93
85
94
private suspend fun processReadBuffer () {
@@ -102,22 +111,20 @@ public class StdioServerTransport(
102
111
}
103
112
104
113
override suspend fun close () {
105
- if (! initialized.compareAndSet(true , false )) return
114
+ if (! initialized.compareAndSet(expectedValue = true , newValue = false )) return
106
115
107
116
// Cancel reading job and close channel
108
117
readingJob?.cancel() // ToDO("was cancel and join")
118
+ sendingJob?.cancel()
119
+
109
120
readChannel.close()
121
+ writeChannel.close()
110
122
readBuffer.clear()
111
123
112
124
_onClose .invoke()
113
125
}
114
126
115
127
override suspend fun send (message : JSONRPCMessage ) {
116
- val json = serializeMessage(message)
117
- lock.withLock {
118
- // You may need to add Content-Length headers before the message if using the LSP framing protocol
119
- outputWriter.writeString(json)
120
- outputWriter.flush()
121
- }
128
+ writeChannel.send(message)
122
129
}
123
130
}
0 commit comments