|
1 | 1 | /* |
2 | | - * Copyright © 2024-2025 Apple Inc. and the Pkl project authors. All rights reserved. |
| 2 | + * Copyright © 2024-2026 Apple Inc. and the Pkl project authors. All rights reserved. |
3 | 3 | * |
4 | 4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | 5 | * you may not use this file except in compliance with the License. |
|
15 | 15 | */ |
16 | 16 | package org.pkl.server |
17 | 17 |
|
| 18 | +import java.io.PipedInputStream |
| 19 | +import java.io.PipedOutputStream |
18 | 20 | import java.net.URI |
19 | 21 | import java.nio.file.Path |
| 22 | +import java.util.concurrent.CountDownLatch |
20 | 23 | import java.util.concurrent.ExecutorService |
21 | 24 | import java.util.concurrent.Executors |
| 25 | +import java.util.concurrent.atomic.AtomicInteger |
22 | 26 | import kotlin.io.path.createDirectories |
23 | 27 | import kotlin.io.path.outputStream |
24 | 28 | import kotlin.io.path.writeText |
@@ -839,6 +843,100 @@ abstract class AbstractServerTest { |
839 | 843 | ) |
840 | 844 | } |
841 | 845 |
|
| 846 | + /** |
| 847 | + * Regression test for concurrent message encoding. |
| 848 | + * |
| 849 | + * The pkl server's main thread sends [CreateEvaluatorResponse] while the executor thread sends |
| 850 | + * [EvaluateResponse]. Without synchronization on the encoder, these writes interleave on the |
| 851 | + * output stream, corrupting the MessagePack framing. |
| 852 | + * |
| 853 | + * This test exercises the race directly: two threads write different message types through the |
| 854 | + * same [ServerMessagePackEncoder] into a pipe, and a reader thread decodes every message. Any |
| 855 | + * interleaved write produces a decode error. |
| 856 | + * |
| 857 | + * Only meaningful with `USE_DIRECT_TRANSPORT = false` (the default). |
| 858 | + */ |
| 859 | + @Test |
| 860 | + fun `concurrent encoding -- multiple evaluators with module reads`() { |
| 861 | + if (USE_DIRECT_TRANSPORT) return |
| 862 | + |
| 863 | + val pipeIn = PipedInputStream(1 shl 20) // 1 MB buffer |
| 864 | + val pipeOut = PipedOutputStream(pipeIn) |
| 865 | + val encoder = ServerMessagePackEncoder(pipeOut) |
| 866 | + val decoder = ServerMessagePackDecoder(pipeIn) |
| 867 | + |
| 868 | + val iterations = 2000 |
| 869 | + val padding = ByteArray(8192) // large payload to widen the race window |
| 870 | + val errors = mutableListOf<Throwable>() |
| 871 | + val decoded = AtomicInteger(0) |
| 872 | + val done = CountDownLatch(2) |
| 873 | + |
| 874 | + // Writer A: CreateEvaluatorResponse (small messages) |
| 875 | + val writerA = Thread { |
| 876 | + try { |
| 877 | + for (i in 0 until iterations) { |
| 878 | + encoder.encode(CreateEvaluatorResponse(i.toLong(), i.toLong(), null)) |
| 879 | + } |
| 880 | + } catch (e: Exception) { |
| 881 | + synchronized(errors) { errors.add(e) } |
| 882 | + } finally { |
| 883 | + done.countDown() |
| 884 | + } |
| 885 | + } |
| 886 | + |
| 887 | + // Writer B: EvaluateResponse (large messages with 8 KB payload) |
| 888 | + val writerB = Thread { |
| 889 | + try { |
| 890 | + for (i in 0 until iterations) { |
| 891 | + encoder.encode(EvaluateResponse(i.toLong() + iterations, i.toLong(), padding, null)) |
| 892 | + } |
| 893 | + } catch (e: Exception) { |
| 894 | + synchronized(errors) { errors.add(e) } |
| 895 | + } finally { |
| 896 | + done.countDown() |
| 897 | + } |
| 898 | + } |
| 899 | + |
| 900 | + // Reader: decode all messages, check each is well-formed. |
| 901 | + val reader = Thread { |
| 902 | + try { |
| 903 | + while (decoded.get() < iterations * 2) { |
| 904 | + val msg = decoder.decode() ?: break |
| 905 | + decoded.incrementAndGet() |
| 906 | + when (msg) { |
| 907 | + is CreateEvaluatorResponse -> {} |
| 908 | + is EvaluateResponse -> {} |
| 909 | + else -> |
| 910 | + synchronized(errors) { |
| 911 | + errors.add(AssertionError("Wrong message type: ${msg.javaClass.simpleName}")) |
| 912 | + } |
| 913 | + } |
| 914 | + } |
| 915 | + } catch (e: Exception) { |
| 916 | + synchronized(errors) { errors.add(e) } |
| 917 | + } |
| 918 | + } |
| 919 | + |
| 920 | + reader.start() |
| 921 | + writerA.start() |
| 922 | + writerB.start() |
| 923 | + |
| 924 | + done.await(30, java.util.concurrent.TimeUnit.SECONDS) |
| 925 | + pipeOut.close() |
| 926 | + reader.join(10_000) |
| 927 | + |
| 928 | + synchronized(errors) { |
| 929 | + if (errors.isNotEmpty()) { |
| 930 | + throw AssertionError( |
| 931 | + "${errors.size} encoding errors (decoded ${decoded.get()}/${iterations * 2}): " + |
| 932 | + errors.first().message, |
| 933 | + errors.first(), |
| 934 | + ) |
| 935 | + } |
| 936 | + } |
| 937 | + assertThat(decoded.get()).isEqualTo(iterations * 2) |
| 938 | + } |
| 939 | + |
842 | 940 | @Test |
843 | 941 | fun `evaluate with project dependencies`(@TempDir tempDir: Path) { |
844 | 942 | val cacheDir = tempDir.resolve("cache").createDirectories() |
|
0 commit comments