Skip to content

Commit 0d7ad11

Browse files
committed
feat: Add support for switching scheduler
1 parent 189c893 commit 0d7ad11

File tree

11 files changed

+241
-35
lines changed

11 files changed

+241
-35
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.pekko.dispatch
19+
20+
import com.typesafe.config.ConfigFactory
21+
22+
import org.apache.pekko
23+
import pekko.actor.{ Actor, Props }
24+
import pekko.testkit.{ ImplicitSender, PekkoSpec }
25+
import pekko.util.JavaVersion
26+
27+
object ForkJoinPoolVirtualThreadSpec {
28+
val config = ConfigFactory.parseString("""
29+
|custom {
30+
| task-dispatcher {
31+
| mailbox-type = "org.apache.pekko.dispatch.SingleConsumerOnlyUnboundedMailbox"
32+
| throughput = 5
33+
| fork-join-executor {
34+
| parallelism-factor = 2
35+
| parallelism-max = 2
36+
| parallelism-min = 2
37+
| virtualize = on
38+
| }
39+
| }
40+
|}
41+
""".stripMargin)
42+
43+
class ThreadNameActor extends Actor {
44+
45+
override def receive = {
46+
case "ping" =>
47+
sender() ! Thread.currentThread().getName
48+
}
49+
}
50+
51+
}
52+
53+
class ForkJoinPoolVirtualThreadSpec extends PekkoSpec(ForkJoinPoolVirtualThreadSpec.config) with ImplicitSender {
54+
import ForkJoinPoolVirtualThreadSpec._
55+
56+
"PekkoForkJoinPool" must {
57+
58+
"support virtualization with Virtual Thread" in {
59+
val actor = system.actorOf(Props(new ThreadNameActor).withDispatcher("custom.task-dispatcher"))
60+
for (_ <- 1 to 1000) {
61+
actor ! "ping"
62+
expectMsgPF() { case name: String =>
63+
name should include("ForkJoinPoolVirtualThreadSpec-custom.task-dispatcher-virtual-thread-")
64+
}
65+
}
66+
}
67+
68+
}
69+
}

actor/src/main/resources/reference.conf

+12
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,18 @@ pekko {
487487
# This config is new in Pekko v1.1.0 and only has an effect if you are running with JDK 9 and above.
488488
# Read the documentation on `java.util.concurrent.ForkJoinPool` to find out more. Default in hex is 0x7fff.
489489
maximum-pool-size = 32767
490+
491+
# This config is new in Pekko v1.2.0 and only has an effect if you are running with JDK 21 and above,
492+
# When set to `on` but underlying runtime does not support virtual threads, an Exception will throw.
493+
# Virtualize this dispatcher as a virtual-thread-executor
494+
# Valid values are: `on`, `off`
495+
#
496+
# Requirements:
497+
# 1. JDK 21+
498+
# 2. add options to the JVM:
499+
# --add-opens=java.base/jdk.internal.misc=ALL-UNNAMED
500+
# --add-opens=java.base/java.lang=ALL-UNNAMED
501+
virtualize = off
490502
}
491503

492504
# This will be used if you have set "executor = "thread-pool-executor""

actor/src/main/scala/org/apache/pekko/dispatch/AbstractDispatcher.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ final class VirtualThreadExecutorConfigurator(config: Config, prerequisites: Dis
453453
}
454454
}
455455
new VirtualizedExecutorService(
456-
tf,
456+
tf, // the virtual thread factory
457457
pool, // the default scheduler of virtual thread
458458
loadMetricsProvider,
459459
cascadeShutdown = false // we don't want to cascade shutdown the default virtual thread scheduler

actor/src/main/scala/org/apache/pekko/dispatch/ForkJoinExecutorConfigurator.scala

+66-12
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@
1414
package org.apache.pekko.dispatch
1515

1616
import com.typesafe.config.Config
17+
import org.apache.pekko
18+
import pekko.dispatch.VirtualThreadSupport.newVirtualThreadFactory
19+
import pekko.util.JavaVersion
1720

1821
import java.lang.invoke.{ MethodHandle, MethodHandles, MethodType }
19-
import java.util.concurrent.{ ExecutorService, ForkJoinPool, ForkJoinTask, ThreadFactory }
22+
import java.util.concurrent.{ Executor, ExecutorService, ForkJoinPool, ForkJoinTask, ThreadFactory }
2023
import scala.util.Try
2124

22-
import org.apache.pekko.util.JavaVersion
23-
2425
object ForkJoinExecutorConfigurator {
2526

2627
/**
@@ -86,15 +87,28 @@ class ForkJoinExecutorConfigurator(config: Config, prerequisites: DispatcherPrer
8687
}
8788

8889
class ForkJoinExecutorServiceFactory(
90+
val id: String,
8991
val threadFactory: ForkJoinPool.ForkJoinWorkerThreadFactory,
9092
val parallelism: Int,
9193
val asyncMode: Boolean,
92-
val maxPoolSize: Int)
94+
val maxPoolSize: Int,
95+
val virtualize: Boolean)
9396
extends ExecutorServiceFactory {
97+
def this(threadFactory: ForkJoinPool.ForkJoinWorkerThreadFactory,
98+
parallelism: Int,
99+
asyncMode: Boolean,
100+
maxPoolSize: Int,
101+
virtualize: Boolean) =
102+
this(null, threadFactory, parallelism, asyncMode, maxPoolSize, virtualize)
94103

95104
def this(threadFactory: ForkJoinPool.ForkJoinWorkerThreadFactory,
96105
parallelism: Int,
97-
asyncMode: Boolean) = this(threadFactory, parallelism, asyncMode, ForkJoinPoolConstants.MaxCap)
106+
asyncMode: Boolean) = this(threadFactory, parallelism, asyncMode, ForkJoinPoolConstants.MaxCap, false)
107+
108+
def this(threadFactory: ForkJoinPool.ForkJoinWorkerThreadFactory,
109+
parallelism: Int,
110+
asyncMode: Boolean,
111+
maxPoolSize: Int) = this(threadFactory, parallelism, asyncMode, maxPoolSize, false)
98112

99113
private def pekkoJdk9ForkJoinPoolClassOpt: Option[Class[_]] =
100114
Try(Class.forName("org.apache.pekko.dispatch.PekkoJdk9ForkJoinPool")).toOption
@@ -116,12 +130,50 @@ class ForkJoinExecutorConfigurator(config: Config, prerequisites: DispatcherPrer
116130
def this(threadFactory: ForkJoinPool.ForkJoinWorkerThreadFactory, parallelism: Int) =
117131
this(threadFactory, parallelism, asyncMode = true)
118132

119-
def createExecutorService: ExecutorService = pekkoJdk9ForkJoinPoolHandleOpt match {
120-
case Some(handle) =>
121-
handle.invoke(parallelism, threadFactory, maxPoolSize,
122-
MonitorableThreadFactory.doNothing, asyncMode).asInstanceOf[ExecutorService]
123-
case _ =>
124-
new PekkoForkJoinPool(parallelism, threadFactory, MonitorableThreadFactory.doNothing, asyncMode)
133+
def createExecutorService: ExecutorService = {
134+
val tf = if (virtualize && JavaVersion.majorVersion >= 21) {
135+
threadFactory match {
136+
// we need to use the thread factory to create carrier thread
137+
case m: MonitorableThreadFactory => new MonitorableCarrierThreadFactory(m.name)
138+
case _ => threadFactory
139+
}
140+
} else threadFactory
141+
142+
val pool = pekkoJdk9ForkJoinPoolHandleOpt match {
143+
case Some(handle) =>
144+
// carrier Thread only exists in JDK 17+
145+
handle.invoke(parallelism, tf, maxPoolSize, MonitorableThreadFactory.doNothing, asyncMode)
146+
.asInstanceOf[ExecutorService with LoadMetrics]
147+
case _ =>
148+
new PekkoForkJoinPool(parallelism, tf, MonitorableThreadFactory.doNothing, asyncMode)
149+
}
150+
151+
if (virtualize && JavaVersion.majorVersion >= 21) {
152+
// when virtualized, we need enhanced thread factory
153+
val factory: ThreadFactory = threadFactory match {
154+
case MonitorableThreadFactory(name, _, contextClassLoader, exceptionHandler, _) =>
155+
new ThreadFactory {
156+
private val vtFactory = newVirtualThreadFactory(name, pool) // use the pool as the scheduler
157+
158+
override def newThread(r: Runnable): Thread = {
159+
val vt = vtFactory.newThread(r)
160+
vt.setUncaughtExceptionHandler(exceptionHandler)
161+
contextClassLoader.foreach(vt.setContextClassLoader)
162+
vt
163+
}
164+
}
165+
case _ => newVirtualThreadFactory(prerequisites.settings.name, pool); // use the pool as the scheduler
166+
}
167+
// wrap the pool with virtualized executor service
168+
new VirtualizedExecutorService(
169+
factory, // the virtual thread factory
170+
pool, // the underlying pool
171+
(_: Executor) => pool.atFullThrottle(), // the load metrics provider, we use the pool itself
172+
cascadeShutdown = true // cascade shutdown
173+
)
174+
} else {
175+
pool
176+
}
125177
}
126178
}
127179

@@ -143,12 +195,14 @@ class ForkJoinExecutorConfigurator(config: Config, prerequisites: DispatcherPrer
143195
}
144196

145197
new ForkJoinExecutorServiceFactory(
198+
id,
146199
validate(tf),
147200
ThreadPoolConfig.scaledPoolSize(
148201
config.getInt("parallelism-min"),
149202
config.getDouble("parallelism-factor"),
150203
config.getInt("parallelism-max")),
151204
asyncMode,
152-
config.getInt("maximum-pool-size"))
205+
config.getInt("maximum-pool-size"),
206+
config.getBoolean("virtualize"))
153207
}
154208
}

actor/src/main/scala/org/apache/pekko/dispatch/ThreadPoolBuilder.scala

+12
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,18 @@ final case class MonitorableThreadFactory(
235235
}
236236
}
237237

238+
class MonitorableCarrierThreadFactory(name: String)
239+
extends ForkJoinPool.ForkJoinWorkerThreadFactory {
240+
private val counter = new AtomicLong(0L)
241+
242+
def newThread(pool: ForkJoinPool): ForkJoinWorkerThread = {
243+
val thread = VirtualThreadSupport.CarrierThreadFactory.newThread(pool)
244+
// Name of the threads for the ForkJoinPool are not customizable. Change it here.
245+
thread.setName(name + "-" + "CarrierThread" + "-" + counter.incrementAndGet())
246+
thread
247+
}
248+
}
249+
238250
/**
239251
* As the name says
240252
*/

actor/src/main/scala/org/apache/pekko/dispatch/VirtualThreadSupport.scala

+53-15
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717

1818
package org.apache.pekko.dispatch
1919

20-
import org.apache.pekko.annotation.InternalApi
21-
import org.apache.pekko.util.JavaVersion
20+
import org.apache.pekko
21+
import pekko.annotation.InternalApi
22+
import pekko.util.JavaVersion
2223

2324
import java.lang.invoke.{ MethodHandles, MethodType }
24-
import java.util.concurrent.{ ExecutorService, ForkJoinPool, ThreadFactory }
25+
import java.util.concurrent.{ ExecutorService, ForkJoinPool, ForkJoinWorkerThread, ThreadFactory }
2526
import scala.util.control.NonFatal
2627

2728
@InternalApi
@@ -34,8 +35,26 @@ private[dispatch] object VirtualThreadSupport {
3435
val isSupported: Boolean = JavaVersion.majorVersion >= 21
3536

3637
/**
37-
* Create a virtual thread factory with a executor, the executor will be used as the scheduler of
38-
* virtual thread.
38+
* Create a newThreadPerTaskExecutor with the specified thread factory.
39+
*/
40+
def newThreadPerTaskExecutor(threadFactory: ThreadFactory): ExecutorService = {
41+
require(threadFactory != null, "threadFactory should not be null.")
42+
try {
43+
val executorsClazz = ClassLoader.getSystemClassLoader.loadClass("java.util.concurrent.Executors")
44+
val newThreadPerTaskExecutorMethod = lookup.findStatic(
45+
executorsClazz,
46+
"newThreadPerTaskExecutor",
47+
MethodType.methodType(classOf[ExecutorService], classOf[ThreadFactory]))
48+
newThreadPerTaskExecutorMethod.invoke(threadFactory).asInstanceOf[ExecutorService]
49+
} catch {
50+
case NonFatal(e) =>
51+
// --add-opens java.base/java.lang=ALL-UNNAMED
52+
throw new UnsupportedOperationException("Failed to create newThreadPerTaskExecutor.", e)
53+
}
54+
}
55+
56+
/**
57+
* Create a virtual thread factory with the default Virtual Thread executor.
3958
*/
4059
def newVirtualThreadFactory(prefix: String): ThreadFactory = {
4160
require(isSupported, "Virtual thread is not supported.")
@@ -57,19 +76,38 @@ private[dispatch] object VirtualThreadSupport {
5776
}
5877
}
5978

60-
def newThreadPerTaskExecutor(threadFactory: ThreadFactory): ExecutorService = {
61-
require(threadFactory != null, "threadFactory should not be null.")
79+
/**
80+
* Create a virtual thread factory with the specified executor as the scheduler of virtual thread.
81+
*/
82+
def newVirtualThreadFactory(prefix: String, executor: ExecutorService): ThreadFactory =
6283
try {
63-
val executorsClazz = ClassLoader.getSystemClassLoader.loadClass("java.util.concurrent.Executors")
64-
val newThreadPerTaskExecutorMethod = lookup.findStatic(
65-
executorsClazz,
66-
"newThreadPerTaskExecutor",
67-
MethodType.methodType(classOf[ExecutorService], classOf[ThreadFactory]))
68-
newThreadPerTaskExecutorMethod.invoke(threadFactory).asInstanceOf[ExecutorService]
84+
val builderClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder")
85+
val ofVirtualClass = ClassLoader.getSystemClassLoader.loadClass("java.lang.Thread$Builder$OfVirtual")
86+
val ofVirtualMethod = classOf[Thread].getDeclaredMethod("ofVirtual")
87+
var builder = ofVirtualMethod.invoke(null)
88+
if (executor != null) {
89+
val clazz = builder.getClass
90+
val field = clazz.getDeclaredField("scheduler")
91+
field.setAccessible(true)
92+
field.set(builder, executor)
93+
}
94+
val nameMethod = ofVirtualClass.getDeclaredMethod("name", classOf[String], classOf[Long])
95+
val factoryMethod = builderClass.getDeclaredMethod("factory")
96+
val zero = java.lang.Long.valueOf(0L)
97+
builder = nameMethod.invoke(builder, prefix + "-virtual-thread-", zero)
98+
factoryMethod.invoke(builder).asInstanceOf[ThreadFactory]
6999
} catch {
70100
case NonFatal(e) =>
71101
// --add-opens java.base/java.lang=ALL-UNNAMED
72-
throw new UnsupportedOperationException("Failed to create newThreadPerTaskExecutor.", e)
102+
throw new UnsupportedOperationException("Failed to create virtual thread factory", e)
103+
}
104+
105+
object CarrierThreadFactory extends ForkJoinPool.ForkJoinWorkerThreadFactory {
106+
private val clazz = ClassLoader.getSystemClassLoader.loadClass("jdk.internal.misc.CarrierThread")
107+
// TODO lookup.findClass is only available in Java 9
108+
private val constructor = clazz.getDeclaredConstructor(classOf[ForkJoinPool])
109+
override def newThread(pool: ForkJoinPool): ForkJoinWorkerThread = {
110+
constructor.newInstance(pool).asInstanceOf[ForkJoinWorkerThread]
73111
}
74112
}
75113

@@ -79,7 +117,7 @@ private[dispatch] object VirtualThreadSupport {
79117
def getVirtualThreadDefaultScheduler: ForkJoinPool =
80118
try {
81119
require(isSupported, "Virtual thread is not supported.")
82-
val clazz = Class.forName("java.lang.VirtualThread")
120+
val clazz = ClassLoader.getSystemClassLoader.loadClass("java.lang.VirtualThread")
83121
val fieldName = "DEFAULT_SCHEDULER"
84122
val field = clazz.getDeclaredField(fieldName)
85123
field.setAccessible(true)

actor/src/main/scala/org/apache/pekko/dispatch/VirtualizedExecutorService.scala

-7
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,6 @@ final class VirtualizedExecutorService(
3939
require(vtFactory != null, "Virtual thread factory must not be null")
4040
require(loadMetricsProvider != null, "Load metrics provider must not be null")
4141

42-
def this(prefix: String,
43-
underlying: ExecutorService,
44-
loadMetricsProvider: Executor => Boolean,
45-
cascadeShutdown: Boolean) = {
46-
this(VirtualThreadSupport.newVirtualThreadFactory(prefix), underlying, loadMetricsProvider, cascadeShutdown)
47-
}
48-
4942
private val executor = VirtualThreadSupport.newThreadPerTaskExecutor(vtFactory)
5043

5144
override def atFullThrottle(): Boolean = loadMetricsProvider(this)

docs/src/main/paradox/dispatchers.md

+10
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,16 @@ You can read more about parallelism in the JDK's [ForkJoinPool documentation](ht
4444

4545
When Running on Java 9+, you can use `maximum-pool-size` to set the upper bound on the total number of threads allocated by the ForkJoinPool.
4646

47+
**Experimental**: When Running on Java 21+, you can use `virtualize=on` to enable the virtual threads feature.
48+
When using virtual threads, all virtual threads will use the same `unparker`, so you may want to
49+
increase the number of `jdk.unparker.maxPoolSize`.
50+
51+
#### Requirements:
52+
53+
1. JDK 21+
54+
2. add options to the JVM:
55+
- `--add-opens=java.base/jdk.internal.misc=ALL-UNNAMED`
56+
- `--add-opens=java.base/java.lang=ALL-UNNAMED`
4757
@@@
4858

4959
Another example that uses the "thread-pool-executor":

docs/src/main/paradox/typed/dispatchers.md

+11
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,17 @@ You can read more about parallelism in the JDK's [ForkJoinPool documentation](ht
129129

130130
When Running on Java 9+, you can use `maximum-pool-size` to set the upper bound on the total number of threads allocated by the ForkJoinPool.
131131

132+
**Experimental**: When Running on Java 21+, you can use `virtualize=on` to enable the virtual threads feature.
133+
When using virtual threads, all virtual threads will use the same `unparker`, so you may want to
134+
increase the number of `jdk.unparker.maxPoolSize`.
135+
136+
#### Requirements:
137+
138+
1. JDK 21+
139+
2. add options to the JVM:
140+
- `--add-opens=java.base/jdk.internal.misc=ALL-UNNAMED`
141+
- `--add-opens=java.base/java.lang=ALL-UNNAMED`
142+
132143
@@@
133144

134145
@@@ note

0 commit comments

Comments
 (0)