Skip to content

Commit 8e7eb2b

Browse files
authored
Fix data race in MessagePack encoder for concurrent server sends (#1486)
1 parent 5803359 commit 8e7eb2b

2 files changed

Lines changed: 111 additions & 6 deletions

File tree

pkl-core/src/main/java/org/pkl/core/messaging/AbstractMessagePackEncoder.java

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright © 2024-2025 Apple Inc. and the Pkl project authors. All rights reserved.
2+
* Copyright © 2024-2026 Apple Inc. and the Pkl project authors. All rights reserved.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -41,10 +41,17 @@ public AbstractMessagePackEncoder(OutputStream stream) {
4141

4242
@Override
4343
public final void encode(Message msg) throws IOException, ProtocolException {
44-
packer.packArrayHeader(2);
45-
packer.packInt(msg.type().getCode());
46-
encodeMessage(msg);
47-
packer.flush();
44+
// Serialize access to the packer. In pkl server mode the main thread
45+
// (handling CreateEvaluatorRequest) and the executor thread (sending
46+
// EvaluateResponse / ReadModuleRequest) call encode() concurrently.
47+
// Without this lock their writes interleave, corrupting the MessagePack
48+
// stream. Regression test: "concurrent encoding" in AbstractServerTest (see JvmServerTest).
49+
synchronized (packer) {
50+
packer.packArrayHeader(2);
51+
packer.packInt(msg.type().getCode());
52+
encodeMessage(msg);
53+
packer.flush();
54+
}
4855
}
4956

5057
protected void packMapHeader(int size, @Nullable Object value1) throws IOException {

pkl-server/src/test/kotlin/org/pkl/server/AbstractServerTest.kt

Lines changed: 99 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright © 2024-2025 Apple Inc. and the Pkl project authors. All rights reserved.
2+
* Copyright © 2024-2026 Apple Inc. and the Pkl project authors. All rights reserved.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -15,10 +15,14 @@
1515
*/
1616
package org.pkl.server
1717

18+
import java.io.PipedInputStream
19+
import java.io.PipedOutputStream
1820
import java.net.URI
1921
import java.nio.file.Path
22+
import java.util.concurrent.CountDownLatch
2023
import java.util.concurrent.ExecutorService
2124
import java.util.concurrent.Executors
25+
import java.util.concurrent.atomic.AtomicInteger
2226
import kotlin.io.path.createDirectories
2327
import kotlin.io.path.outputStream
2428
import kotlin.io.path.writeText
@@ -839,6 +843,100 @@ abstract class AbstractServerTest {
839843
)
840844
}
841845

846+
/**
 * Regression test for concurrent message encoding.
 *
 * The pkl server's main thread sends [CreateEvaluatorResponse] while the executor thread sends
 * [EvaluateResponse]. Without synchronization on the encoder, these writes interleave on the
 * output stream, corrupting the MessagePack framing.
 *
 * This test exercises the race directly: two threads write different message types through the
 * same [ServerMessagePackEncoder] into a pipe, and a reader thread decodes every message. Any
 * failure raised while encoding or decoding is collected and reported with the number of messages
 * decoded so far. The test also fails explicitly if the writers or the reader do not finish within
 * their timeouts, instead of silently continuing with a partial result.
 *
 * Only meaningful with `USE_DIRECT_TRANSPORT = false` (the default).
 */
@Test
fun `concurrent encoding -- multiple evaluators with module reads`() {
  if (USE_DIRECT_TRANSPORT) return

  val pipeIn = PipedInputStream(1 shl 20) // 1 MB buffer
  val pipeOut = PipedOutputStream(pipeIn)
  val encoder = ServerMessagePackEncoder(pipeOut)
  val decoder = ServerMessagePackDecoder(pipeIn)

  val iterations = 2000
  val expectedMessages = iterations * 2
  val padding = ByteArray(8192) // large payload to widen the race window
  val errors = mutableListOf<Throwable>()
  val decoded = AtomicInteger(0)
  val done = CountDownLatch(2)

  // All worker threads report failures through this single synchronized entry point.
  fun recordError(error: Throwable) {
    synchronized(errors) { errors.add(error) }
  }

  // Builds a writer thread that calls [send] once per iteration. Throwable (not just Exception)
  // is caught so that an Error on a worker thread is reported instead of being lost.
  // Daemon threads cannot keep the test JVM alive if something hangs.
  fun writer(name: String, send: (Int) -> Unit): Thread =
    Thread(
        {
          try {
            for (i in 0 until iterations) send(i)
          } catch (e: Throwable) {
            recordError(e)
          } finally {
            done.countDown()
          }
        },
        name,
      )
      .apply { isDaemon = true }

  // Writer A: CreateEvaluatorResponse (small messages)
  val writerA =
    writer("concurrent-encoding-writer-a") { i ->
      encoder.encode(CreateEvaluatorResponse(i.toLong(), i.toLong(), null))
    }

  // Writer B: EvaluateResponse (large messages with 8 KB payload)
  val writerB =
    writer("concurrent-encoding-writer-b") { i ->
      encoder.encode(EvaluateResponse(i.toLong() + iterations, i.toLong(), padding, null))
    }

  // Reader: decode all messages, check each is well-formed.
  val reader =
    Thread(
        {
          try {
            while (decoded.get() < expectedMessages) {
              val msg = decoder.decode() ?: break
              decoded.incrementAndGet()
              when (msg) {
                is CreateEvaluatorResponse,
                is EvaluateResponse -> {}
                else ->
                  recordError(AssertionError("Wrong message type: ${msg.javaClass.simpleName}"))
              }
            }
          } catch (e: Throwable) {
            // A corrupted stream can surface as an Error (e.g. a huge bogus length prefix).
            recordError(e)
          }
        },
        "concurrent-encoding-reader",
      )
      .apply { isDaemon = true }

  val writersFinished =
    try {
      reader.start()
      writerA.start()
      writerB.start()
      done.await(30, java.util.concurrent.TimeUnit.SECONDS)
    } finally {
      // Closing the write end signals end-of-stream to the reader and makes any writer that is
      // still running fail fast instead of blocking on a full pipe.
      pipeOut.close()
    }
  reader.join(10_000)
  val readerFinished = !reader.isAlive

  // Snapshot under the lock: a worker that missed its timeout may still be appending.
  val failures = synchronized(errors) { errors.toList() }
  failures.firstOrNull()?.let { first ->
    throw AssertionError(
      "${failures.size} encoding errors (decoded ${decoded.get()}/$expectedMessages): " +
        first.message,
      first,
    )
  }
  if (!writersFinished) {
    throw AssertionError(
      "Writers did not finish within 30 seconds (decoded ${decoded.get()}/$expectedMessages)"
    )
  }
  if (!readerFinished) {
    throw AssertionError(
      "Reader did not finish within 10 seconds (decoded ${decoded.get()}/$expectedMessages)"
    )
  }
  assertThat(decoded.get()).isEqualTo(expectedMessages)
}
939+
842940
@Test
843941
fun `evaluate with project dependencies`(@TempDir tempDir: Path) {
844942
val cacheDir = tempDir.resolve("cache").createDirectories()

0 commit comments

Comments
 (0)