|
17 | 17 |
|
18 | 18 | package org.apache.kyuubi.util |
19 | 19 |
|
20 | | -import java.util.concurrent.TimeUnit |
| 20 | +import java.lang.Thread.sleep |
| 21 | +import java.util.concurrent._ |
| 22 | + |
| 23 | +import scala.collection.concurrent.TrieMap |
| 24 | + |
| 25 | +import org.scalatest.time.{Millis, Seconds, Span} |
21 | 26 |
|
22 | 27 | import org.apache.kyuubi.KyuubiFunSuite |
23 | 28 |
|
24 | 29 | class ThreadUtilsSuite extends KyuubiFunSuite { |
25 | | - |
| 30 | + // Configure Eventually patience for retries/waits |
| 31 | + implicit override val patienceConfig: PatienceConfig = PatienceConfig( |
| 32 | + timeout = scaled(Span(5, Seconds)), |
| 33 | + interval = scaled(Span(10, Millis))) |
26 | 34 | test("New daemon single thread scheduled executor for shutdown") { |
27 | 35 | val service = ThreadUtils.newDaemonSingleThreadScheduledExecutor("ThreadUtilsTest") |
28 | 36 | @volatile var threadName = "" |
@@ -61,4 +69,181 @@ class ThreadUtilsSuite extends KyuubiFunSuite { |
61 | 69 | service.awaitTermination(10, TimeUnit.SECONDS) |
62 | 70 | assert(threadName startsWith "") |
63 | 71 | } |
| 72 | + |
| 73 | + // Helper function for cleanup |
| 74 | + private def shutdownAndAwaitTermination( |
| 75 | + service: ExecutorService, |
| 76 | + timeoutSeconds: Long = 5): Unit = { |
| 77 | + service.shutdown() // Disable new tasks from being submitted |
| 78 | + try { |
| 79 | + // Wait a while for existing tasks to terminate |
| 80 | + if (!service.awaitTermination(timeoutSeconds / 2, TimeUnit.SECONDS)) { |
| 81 | + service.shutdownNow() // Cancel currently executing tasks |
| 82 | + // Wait a while for tasks to respond to being cancelled |
| 83 | + if (!service.awaitTermination(timeoutSeconds / 2, TimeUnit.SECONDS)) { |
| 84 | + throw new IllegalStateException( |
| 85 | + s"Thread pool did not terminate within $timeoutSeconds seconds") |
| 86 | + } |
| 87 | + } |
| 88 | + } catch { |
| 89 | + case _: InterruptedException => |
| 90 | + // (Re-)Cancel if current thread also interrupted |
| 91 | + service.shutdownNow() |
| 92 | + // Preserve interrupt status |
| 93 | + Thread.currentThread().interrupt() |
| 94 | + } |
| 95 | + } |
| 96 | + test("newDaemonScheduledThreadPool - thread naming and daemon status") { |
| 97 | + val poolSize = 2 |
| 98 | + val threadNamePrefix = "test-pool-thread" |
| 99 | + val executor = ThreadUtils.newDaemonScheduledThreadPool(poolSize, threadNamePrefix) |
| 100 | + val latch = new CountDownLatch(poolSize) |
| 101 | + val threadNames = TrieMap[String, Boolean]() // Thread-safe map |
| 102 | + try { |
| 103 | + for (_ <- 0 until poolSize) { |
| 104 | + executor.submit(new Runnable { |
| 105 | + override def run(): Unit = { |
| 106 | + val currentThread = Thread.currentThread() |
| 107 | + threadNames.put(currentThread.getName, currentThread.isDaemon) |
| 108 | + latch.countDown() |
| 109 | + } |
| 110 | + }) |
| 111 | + } |
| 112 | + // Wait for tasks to complete |
| 113 | + assert(latch.await(5, TimeUnit.SECONDS), "Tasks did not complete in time") |
| 114 | + // Verify thread names and daemon status |
| 115 | + assert(threadNames.size === poolSize) |
| 116 | + threadNames.foreach { case (name, isDaemon) => |
| 117 | + assert( |
| 118 | + name.startsWith(threadNamePrefix), |
| 119 | + s"Thread name '$name' should start with '$threadNamePrefix'") |
| 120 | + assert(isDaemon, s"Thread '$name' should be a daemon thread") |
| 121 | + } |
| 122 | + } finally { |
| 123 | + shutdownAndAwaitTermination(executor) |
| 124 | + } |
| 125 | + } |
| 126 | + test("newDaemonScheduledThreadPool - schedule and execute tasks") { |
| 127 | + val executor = ThreadUtils.newDaemonScheduledThreadPool(1, "test-schedule") |
| 128 | + val taskRan = new java.util.concurrent.atomic.AtomicBoolean(false) |
| 129 | + val latch = new CountDownLatch(1) |
| 130 | + try { |
| 131 | + val future: ScheduledFuture[_] = executor.schedule( |
| 132 | + new Runnable { |
| 133 | + override def run(): Unit = { |
| 134 | + taskRan.set(true) |
| 135 | + latch.countDown() |
| 136 | + } |
| 137 | + }, |
| 138 | + 50, |
| 139 | + TimeUnit.MILLISECONDS |
| 140 | + ) // Schedule with a small delay |
| 141 | + // Wait for the task to execute |
| 142 | + assert(latch.await(2, TimeUnit.SECONDS), "Scheduled task did not run in time") |
| 143 | + assert(taskRan.get(), "Scheduled task flag should be true") |
| 144 | + assert(future.isDone, "Future should be done after task completion") |
| 145 | + } finally { |
| 146 | + shutdownAndAwaitTermination(executor) |
| 147 | + } |
| 148 | + } |
| 149 | + test("newDaemonScheduledThreadPool - removeOnCancelPolicy works for scheduled tasks") { |
| 150 | + // We need the specific ScheduledThreadPoolExecutor type to access the queue |
| 151 | + val executor = ThreadUtils.newDaemonScheduledThreadPool(1, "test-cancel") |
| 152 | + .asInstanceOf[ScheduledThreadPoolExecutor] |
| 153 | + val taskRan = new java.util.concurrent.atomic.AtomicBoolean(false) |
| 154 | + try { |
| 155 | + // Schedule a task far enough in the future that we can cancel it |
| 156 | + val future: ScheduledFuture[_] = executor.schedule( |
| 157 | + new Runnable { |
| 158 | + override def run(): Unit = { |
| 159 | + taskRan.set(true) |
| 160 | + } |
| 161 | + }, |
| 162 | + 5, |
| 163 | + TimeUnit.SECONDS |
| 164 | + ) // Long delay |
| 165 | + // Verify the task is in the queue initially |
| 166 | + eventually { |
| 167 | + assert(executor.getQueue.size() === 1, "Task should be in the queue initially") |
| 168 | + } |
| 169 | + // Cancel the task |
| 170 | + val cancelled = future.cancel(false) // false = don't interrupt if running (it shouldn't be) |
| 171 | + assert(cancelled, "Future.cancel() should return true") |
| 172 | + // Verify the task is removed from the queue due to the policy |
| 173 | + eventually { |
| 174 | + assert( |
| 175 | + executor.getQueue.isEmpty, |
| 176 | + "Task should be removed from the queue after cancellation") |
| 177 | + } |
| 178 | + // Wait a bit and verify the task never ran |
| 179 | + Thread.sleep(100) // Give some time just in case |
| 180 | + assert(!taskRan.get(), "Cancelled task should not have run") |
| 181 | + } finally { |
| 182 | + shutdownAndAwaitTermination(executor) |
| 183 | + // Final check after shutdown |
| 184 | + assert(executor.getQueue.isEmpty, "Queue should be empty after shutdown") |
| 185 | + } |
| 186 | + } |
| 187 | + test("newDaemonScheduledThreadPool - shutdown rejects new tasks") { |
| 188 | + val executor = ThreadUtils.newDaemonScheduledThreadPool(1, "test-shutdown") |
| 189 | + try { |
| 190 | + // Submit one task to ensure the pool is active |
| 191 | + val future = executor.submit(new Runnable { override def run(): Unit = Thread.sleep(50) }) |
| 192 | + future.get(1, TimeUnit.SECONDS) // Wait for it to finish |
| 193 | + executor.shutdown() |
| 194 | + assert(executor.isShutdown, "Executor should be shutdown") |
| 195 | + // Try submitting after shutdown |
| 196 | + assertThrows[RejectedExecutionException] { |
| 197 | + executor.submit(new Runnable { override def run(): Unit = sleep(5) }) |
| 198 | + } |
| 199 | + assertThrows[RejectedExecutionException] { |
| 200 | + executor.schedule( |
| 201 | + new Runnable { override def run(): Unit = sleep(5) }, |
| 202 | + 10, |
| 203 | + TimeUnit.MILLISECONDS) |
| 204 | + } |
| 205 | + } finally { |
| 206 | + // Ensure termination even if already shut down |
| 207 | + shutdownAndAwaitTermination(executor) |
| 208 | + assert(executor.isTerminated, "Executor should be terminated") |
| 209 | + } |
| 210 | + } |
| 211 | + test("newDaemonScheduledThreadPool - concurrent execution with poolSize > 1") { |
| 212 | + val poolSize = 3 |
| 213 | + val taskCount = 5 |
| 214 | + val executor = ThreadUtils.newDaemonScheduledThreadPool(poolSize, "test-concurrent") |
| 215 | + val latch = new CountDownLatch(taskCount) |
| 216 | + val runningThreads = TrieMap[String, Long]() // Thread Name -> Start Time |
| 217 | + val executionTimes = TrieMap[Int, Long]() // Task Index -> Completion Time |
| 218 | + try { |
| 219 | + val startTime = System.nanoTime() |
| 220 | + for (i <- 0 until taskCount) { |
| 221 | + executor.submit(new Runnable { |
| 222 | + override def run(): Unit = { |
| 223 | + val threadName = Thread.currentThread().getName |
| 224 | + runningThreads.put(threadName, System.nanoTime()) |
| 225 | + try { |
| 226 | + // Simulate work |
| 227 | + Thread.sleep(100) |
| 228 | + } finally { |
| 229 | + executionTimes.put(i, System.nanoTime() - startTime) |
| 230 | + latch.countDown() |
| 231 | + } |
| 232 | + } |
| 233 | + }) |
| 234 | + } |
| 235 | + // Wait for all tasks to complete |
| 236 | + assert(latch.await(5, TimeUnit.SECONDS), s"All $taskCount tasks did not complete in time") |
| 237 | + // Verify that multiple threads were used (likely up to poolSize) |
| 238 | + assert( |
| 239 | + runningThreads.size > 1, |
| 240 | + s"Expected more than 1 thread to be used, but found ${runningThreads.size}") |
| 241 | + assert( |
| 242 | + runningThreads.size <= poolSize, |
| 243 | + s"Used ${runningThreads.size} threads, which should not exceed poolSize $poolSize") |
| 244 | + } finally { |
| 245 | + shutdownAndAwaitTermination(executor) |
| 246 | + } |
| 247 | + } |
| 248 | + |
64 | 249 | } |
0 commit comments