Skip to content

Commit a8a2575

Browse files
committed
.
1 parent ee32622 commit a8a2575

14 files changed

Lines changed: 414 additions & 407 deletions

File tree

core/exec/src/mill/exec/ExecutionContexts.scala

Lines changed: 68 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,62 @@ import mill.api.Logger
1111
import mill.api.daemon.internal.NonFatal
1212

1313
object ExecutionContexts {
14+
private object RunnablePriority {
15+
private val priorityMethod = new ClassValue[Option[java.lang.reflect.Method]] {
16+
override def computeValue(clazz: Class[?]): Option[java.lang.reflect.Method] =
17+
try {
18+
val method = clazz.getMethod("priority")
19+
method.trySetAccessible()
20+
Some(method)
21+
} catch {
22+
case _: ReflectiveOperationException => None
23+
}
24+
}
25+
26+
def apply(runnable: Runnable): Int =
27+
priorityMethod.get(runnable.getClass) match {
28+
case Some(method) => method.invoke(runnable).asInstanceOf[Int]
29+
case None => 0
30+
}
31+
}
32+
33+
private final class QueuedRunnable(
34+
runnable: Runnable,
35+
val priority: Int,
36+
val submissionIndex: Long
37+
) extends Runnable
38+
with Comparable[QueuedRunnable] {
39+
def run(): Unit = runnable.run()
40+
41+
override def compareTo(other: QueuedRunnable): Int =
42+
priority.compareTo(other.priority) match {
43+
case 0 => submissionIndex.compareTo(other.submissionIndex)
44+
case n => n
45+
}
46+
}
47+
48+
private final class PriorityThreadPoolExecutor(
49+
threadCount: Int,
50+
threadFactory: ThreadFactory
51+
) extends ThreadPoolExecutor(
52+
threadCount,
53+
threadCount,
54+
60 * 1000,
55+
TimeUnit.SECONDS,
56+
new PriorityBlockingQueue[Runnable](),
57+
threadFactory
58+
) {
59+
private val submissionCount = new java.util.concurrent.atomic.AtomicLong()
60+
61+
override def execute(command: Runnable): Unit =
62+
super.execute(
63+
new QueuedRunnable(
64+
runnable = command,
65+
priority = RunnablePriority(command),
66+
submissionIndex = submissionCount.getAndIncrement()
67+
)
68+
)
69+
}
1470

1571
/**
1672
* Execution context that runs code immediately when scheduled, without
@@ -76,29 +132,15 @@ object ExecutionContexts {
76132
def reportFailure(t: Throwable): Unit = {}
77133
def close(): Unit = executor.shutdown()
78134

79-
val priorityRunnableCount = java.util.concurrent.atomic.AtomicLong()
80-
81135
/**
82136
* Subclass of [[java.lang.Runnable]] that assigns a priority to execute it
83137
*
84138
* Priority 0 is the default priority of all Mill task, priorities <0 can be used to
85139
* prioritize this runnable over most other tasks, while priorities >0 can be used to
86140
* de-prioritize it.
87141
*/
88-
class PriorityRunnable(val priority: Int, run0: () => Unit) extends Runnable
89-
with Comparable[PriorityRunnable] {
90-
def run() = run0()
91-
val priorityRunnableIndex: Long = priorityRunnableCount.getAndIncrement()
92-
override def compareTo(o: PriorityRunnable): Int = priority.compareTo(o.priority) match {
93-
case 0 =>
94-
// `Comparable` wants a *total* ordering, so we need to use `priorityRunnableIndex`
95-
// to break ties between instances with the same priority. This index is assigned
96-
// when a task is submitted, so it should more or less follow insertion order,
97-
// and is a `Long` which should be big enough never to overflow
98-
assert(this == o || this.priorityRunnableIndex != o.priorityRunnableIndex)
99-
this.priorityRunnableIndex.compareTo(o.priorityRunnableIndex)
100-
case n => n
101-
}
142+
class PriorityRunnable(val priority: Int, run0: () => Unit) extends Runnable {
143+
def run(): Unit = run0()
102144
}
103145

104146
/**
@@ -151,25 +193,15 @@ object ExecutionContexts {
151193
def createExecutor(threadCount: Int): ThreadPoolExecutor = {
152194
val executorIndex = executorCounter.incrementAndGet()
153195
val threadCounter = new AtomicInteger
154-
new ThreadPoolExecutor(
155-
threadCount,
156-
threadCount,
157-
60 * 1000,
158-
TimeUnit.SECONDS,
159-
// Use a `Deque` rather than a normal `Queue`, with the various `poll`/`take`
160-
// operations reversed, providing elements in a LIFO order. This ensures that
161-
// child `fork.async` tasks always take priority over parent tasks, avoiding
162-
// large numbers of blocked parent tasks from piling up
163-
new PriorityBlockingQueue[Runnable](),
164-
runnable => {
165-
val threadIndex = threadCounter.incrementAndGet()
166-
val t = new Thread(
167-
runnable,
168-
s"execution-contexts-threadpool-$executorIndex-thread-$threadIndex"
169-
)
170-
t.setDaemon(true)
171-
t
172-
}
173-
)
196+
val threadFactory: ThreadFactory = runnable => {
197+
val threadIndex = threadCounter.incrementAndGet()
198+
val t = new Thread(
199+
runnable,
200+
s"execution-contexts-threadpool-$executorIndex-thread-$threadIndex"
201+
)
202+
t.setDaemon(true)
203+
t
204+
}
205+
new PriorityThreadPoolExecutor(threadCount, threadFactory)
174206
}
175207
}

core/exec/test/src/mill/exec/ExecutionTests.scala

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ import mill.testkit.{TestRootModule, UnitTester}
66
import mill.{PathRef, exec}
77
import utest.*
88

9+
import java.net.URLClassLoader
10+
import java.util.concurrent.{CountDownLatch, TimeUnit}
11+
912
object ExecutionTests extends TestSuite {
1013
object traverseBuild extends TestRootModule {
1114
trait TaskModule extends mill.Module {
@@ -88,6 +91,26 @@ object ExecutionTests extends TestSuite {
8891
class Checker[T <: mill.testkit.TestRootModule](module: T)
8992
extends exec.Checker(module)
9093

94+
private final class ChildFirstExecutionContextsLoader(urls: Array[java.net.URL], parent: ClassLoader)
95+
extends URLClassLoader(urls, parent) {
96+
override def loadClass(name: String, resolve: Boolean): Class[?] = synchronized {
97+
val alreadyLoaded = findLoadedClass(name)
98+
val loaded =
99+
if (alreadyLoaded != null) alreadyLoaded
100+
else if (name == "mill.exec.ExecutionContexts$" || name.startsWith(
101+
"mill.exec.ExecutionContexts$"
102+
)) {
103+
try findClass(name)
104+
catch {
105+
case _: ClassNotFoundException => super.loadClass(name, false)
106+
}
107+
} else super.loadClass(name, false)
108+
109+
if (resolve) resolveClass(loaded)
110+
loaded
111+
}
112+
}
113+
91114
val tests = Tests {
92115
import TestGraphs.*
93116
import utest.*
@@ -568,6 +591,52 @@ object ExecutionTests extends TestSuite {
568591
assert(res.executionResults.transitiveFailing.keySet == Set(anonTaskFailure.task))
569592
}
570593
}
594+
595+
test("threadPoolSupportsMixedClassloaderPriorityRunnables") {
596+
val executor = ExecutionContexts.createExecutor(1)
597+
val localPool = new ExecutionContexts.ThreadPool(executor)
598+
val codeSourceUrl = ExecutionContexts.getClass.getProtectionDomain.getCodeSource.getLocation
599+
val childLoader =
600+
new ChildFirstExecutionContextsLoader(Array(codeSourceUrl), ExecutionTests.getClass.getClassLoader)
601+
602+
val blockerStarted = new CountDownLatch(1)
603+
val unblock = new CountDownLatch(1)
604+
val completed = new CountDownLatch(2)
605+
606+
try {
607+
val threadPoolClass = childLoader.loadClass("mill.exec.ExecutionContexts$ThreadPool")
608+
val ctor = threadPoolClass.getDeclaredConstructors.head
609+
ctor.trySetAccessible()
610+
val childPool = ctor
611+
.newInstance(executor)
612+
.asInstanceOf[mill.api.TaskCtx.Fork.Impl]
613+
614+
localPool.execute(new Runnable {
615+
override def run(): Unit = {
616+
blockerStarted.countDown()
617+
unblock.await(30, TimeUnit.SECONDS)
618+
}
619+
})
620+
621+
assert(blockerStarted.await(30, TimeUnit.SECONDS))
622+
623+
childPool.execute(new Runnable {
624+
override def run(): Unit = completed.countDown()
625+
})
626+
localPool.execute(new Runnable {
627+
override def run(): Unit = completed.countDown()
628+
})
629+
630+
unblock.countDown()
631+
assert(completed.await(30, TimeUnit.SECONDS))
632+
} finally {
633+
unblock.countDown()
634+
localPool.close()
635+
assert(executor.awaitTermination(30, TimeUnit.SECONDS))
636+
childLoader.close()
637+
}
638+
}
639+
571640
test("overloaded") {
572641
UnitTester(overloads, null).scoped { tester =>
573642
val res = tester.apply(Seq(overloads.overloaded(1)))

integration/bsp-util/src/BspServerTestUtil.scala

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ object BspServerTestUtil {
276276
try proc.stdout.close()
277277
catch { case _: java.io.IOException => () }
278278

279-
proc.join(30000L)
279+
waitForProcessExit(proc, 30000L)
280280
} finally {
281281
if (!success && isCI) {
282282
System.err.println(" == BSP server output ==")
@@ -287,6 +287,32 @@ object BspServerTestUtil {
287287
}
288288
}
289289

290+
private def waitForProcessExit(proc: os.SubProcess, timeoutMillis: Long): Unit = {
291+
val waitIntervalMillis = 50L
292+
val deadlineNanos = System.nanoTime() + timeoutMillis * 1000000L
293+
294+
while (proc.isAlive() && System.nanoTime() < deadlineNanos) {
295+
Thread.sleep(waitIntervalMillis)
296+
}
297+
298+
if (proc.isAlive()) {
299+
// `os.SubProcess.join(timeout)` falls back to recursive destruction on
300+
// timeout, which uses `ProcessHandle.children()` and is not permitted in
301+
// this macOS sandbox. BSP launchers are top-level processes for these
302+
// tests, so a non-recursive destroy is sufficient here.
303+
proc.destroy(recursive = false)
304+
305+
val shutdownDeadlineNanos = System.nanoTime() + 5000L * 1000000L
306+
while (proc.isAlive() && System.nanoTime() < shutdownDeadlineNanos) {
307+
Thread.sleep(waitIntervalMillis)
308+
}
309+
310+
if (proc.isAlive()) {
311+
throw new RuntimeException("BSP server did not exit within the expected timeout")
312+
}
313+
}
314+
}
315+
290316
lazy val millWorkspace: os.Path = {
291317
val value = Option(System.getenv("MILL_PROJECT_ROOT")).getOrElse(???)
292318
os.Path(value)

integration/dedicated/bsp-server-error/src/BspServerErrorTests.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,17 @@ object BspServerErrorTests extends UtestIntegrationTestSuite {
4545
bspLog = Some((bytes, len) => stderr.write(bytes, 0, len))
4646
) { (buildServer, initRes) =>
4747

48+
val firstServerDeadline = System.nanoTime() + 5000L * 1000000L
49+
while (firstServerProc.isAlive() && System.nanoTime() < firstServerDeadline) {
50+
Thread.sleep(50L)
51+
}
4852
assert(!firstServerProc.isAlive())
4953

5054
val firstServerStderrStr = new String(firstServerStderr.toByteArray)
51-
assert(firstServerStderrStr.contains("Received SIGTERM, exiting"))
55+
assert(firstServerStderrStr.contains("BSP shutdown asked by client, exiting"))
5256

5357
val currentStderrStr = new String(stderr.toByteArray)
54-
assert(currentStderrStr.contains("Sent SIGTERM to process"))
58+
assert(currentStderrStr.contains("Asked the active BSP session for 'Mill_Integration' to shut down"))
5559

5660
assert(initRes.getCapabilities.getInverseSourcesProvider == true)
5761

0 commit comments

Comments
 (0)