|
36 | 36 | import java.nio.file.Paths; |
37 | 37 | import java.util.Arrays; |
38 | 38 | import java.util.Properties; |
| 39 | +import java.util.Map; |
| 40 | +import java.util.HashMap; |
| 41 | +import java.lang.reflect.Method; |
39 | 42 |
|
40 | 43 | import static org.junit.jupiter.api.Assertions.assertEquals; |
41 | 44 | import static org.junit.jupiter.api.Assertions.assertFalse; |
@@ -325,4 +328,78 @@ void testYarnClusterMode_3() throws IOException { |
325 | 328 | } |
326 | 329 | FileUtils.deleteDirectory(localRepoPath.toFile()); |
327 | 330 | } |
| 331 | + |
| 332 | + @Test |
| 333 | + void testDetectSparkScalaVersionResourceCleanup() throws Exception { |
| 334 | + SparkInterpreterLauncher launcher = new SparkInterpreterLauncher(zConf, null); |
| 335 | + |
| 336 | + // Get temp directory before test |
| 337 | + File tempDir = new File(System.getProperty("java.io.tmpdir")); |
| 338 | + File[] filesBeforeTest = tempDir.listFiles((dir, name) -> name.startsWith("zeppelin-spark") && name.endsWith(".out")); |
| 339 | + int tempFilesCountBefore = filesBeforeTest != null ? filesBeforeTest.length : 0; |
| 340 | + |
| 341 | + // Use reflection to access private method |
| 342 | + Method detectSparkScalaVersionMethod = SparkInterpreterLauncher.class.getDeclaredMethod( |
| 343 | + "detectSparkScalaVersion", String.class, Map.class); |
| 344 | + detectSparkScalaVersionMethod.setAccessible(true); |
| 345 | + |
| 346 | + Map<String, String> env = new HashMap<>(); |
| 347 | + |
| 348 | + try { |
| 349 | + // Call the method |
| 350 | + String scalaVersion = (String) detectSparkScalaVersionMethod.invoke(launcher, sparkHome, env); |
| 351 | + |
| 352 | + // Verify we got a valid result |
| 353 | + assertTrue(scalaVersion.equals("2.12") || scalaVersion.equals("2.13"), |
| 354 | + "Expected scala version 2.12 or 2.13 but got: " + scalaVersion); |
| 355 | + |
| 356 | + // Check that no temp files were left behind |
| 357 | + File[] filesAfterTest = tempDir.listFiles((dir, name) -> name.startsWith("zeppelin-spark") && name.endsWith(".out")); |
| 358 | + int tempFilesCountAfter = filesAfterTest != null ? filesAfterTest.length : 0; |
| 359 | + |
| 360 | + assertEquals(tempFilesCountBefore, tempFilesCountAfter, |
| 361 | + "Temporary files were not cleaned up properly"); |
| 362 | + |
| 363 | + } catch (Exception e) { |
| 364 | + // Even if the method fails, temp files should be cleaned up |
| 365 | + File[] filesAfterException = tempDir.listFiles((dir, name) -> name.startsWith("zeppelin-spark") && name.endsWith(".out")); |
| 366 | + int tempFilesCountAfterException = filesAfterException != null ? filesAfterException.length : 0; |
| 367 | + |
| 368 | + assertEquals(tempFilesCountBefore, tempFilesCountAfterException, |
| 369 | + "Temporary files were not cleaned up after exception"); |
| 370 | + |
| 371 | + // Re-throw to fail the test if needed |
| 372 | + throw e; |
| 373 | + } |
| 374 | + } |
| 375 | + |
| 376 | + @Test |
| 377 | + void testDetectSparkScalaVersionMultipleCalls() throws Exception { |
| 378 | + SparkInterpreterLauncher launcher = new SparkInterpreterLauncher(zConf, null); |
| 379 | + |
| 380 | + // Get temp directory |
| 381 | + File tempDir = new File(System.getProperty("java.io.tmpdir")); |
| 382 | + File[] filesBeforeTest = tempDir.listFiles((dir, name) -> name.startsWith("zeppelin-spark") && name.endsWith(".out")); |
| 383 | + int tempFilesCountBefore = filesBeforeTest != null ? filesBeforeTest.length : 0; |
| 384 | + |
| 385 | + // Use reflection to access private method |
| 386 | + Method detectSparkScalaVersionMethod = SparkInterpreterLauncher.class.getDeclaredMethod( |
| 387 | + "detectSparkScalaVersion", String.class, Map.class); |
| 388 | + detectSparkScalaVersionMethod.setAccessible(true); |
| 389 | + |
| 390 | + Map<String, String> env = new HashMap<>(); |
| 391 | + |
| 392 | + // Call the method multiple times to ensure resources are properly cleaned each time |
| 393 | + for (int i = 0; i < 5; i++) { |
| 394 | + String scalaVersion = (String) detectSparkScalaVersionMethod.invoke(launcher, sparkHome, env); |
| 395 | + assertTrue(scalaVersion.equals("2.12") || scalaVersion.equals("2.13")); |
| 396 | + } |
| 397 | + |
| 398 | + // Check that no temp files accumulated |
| 399 | + File[] filesAfterTest = tempDir.listFiles((dir, name) -> name.startsWith("zeppelin-spark") && name.endsWith(".out")); |
| 400 | + int tempFilesCountAfter = filesAfterTest != null ? filesAfterTest.length : 0; |
| 401 | + |
| 402 | + assertEquals(tempFilesCountBefore, tempFilesCountAfter, |
| 403 | + "Temporary files accumulated after multiple calls"); |
| 404 | + } |
328 | 405 | } |
0 commit comments