Skip to content

Commit e36d435

Browse files
committed
[ZEPPELIN-6256] Fix resource leaks in SparkInterpreterLauncher.detectSparkScalaVersion
1 parent 9556b38 commit e36d435

2 files changed

Lines changed: 103 additions & 14 deletions

File tree

zeppelin-zengine/src/main/java/org/apache/zeppelin/interpreter/launcher/SparkInterpreterLauncher.java

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -272,22 +272,34 @@ private String detectSparkScalaVersion(String sparkHome, Map<String, String> env
272272
builder.environment().putAll(env);
273273
File processOutputFile = File.createTempFile("zeppelin-spark", ".out");
274274
builder.redirectError(processOutputFile);
275-
Process process = builder.start();
276-
process.waitFor();
277-
String processOutput = IOUtils.toString(new FileInputStream(processOutputFile), StandardCharsets.UTF_8);
278-
Pattern pattern = Pattern.compile(".*Using Scala version (.*),.*");
279-
Matcher matcher = pattern.matcher(processOutput);
280-
if (matcher.find()) {
281-
String scalaVersion = matcher.group(1);
282-
if (scalaVersion.startsWith("2.12")) {
283-
return "2.12";
284-
} else if (scalaVersion.startsWith("2.13")) {
285-
return "2.13";
275+
276+
try {
277+
Process process = builder.start();
278+
process.waitFor();
279+
280+
String processOutput;
281+
try (FileInputStream in = new FileInputStream(processOutputFile)) {
282+
processOutput = IOUtils.toString(in, StandardCharsets.UTF_8);
283+
}
284+
285+
Pattern pattern = Pattern.compile(".*Using Scala version (.*),.*");
286+
Matcher matcher = pattern.matcher(processOutput);
287+
if (matcher.find()) {
288+
String scalaVersion = matcher.group(1);
289+
if (scalaVersion.startsWith("2.12")) {
290+
return "2.12";
291+
} else if (scalaVersion.startsWith("2.13")) {
292+
return "2.13";
293+
} else {
294+
throw new Exception("Unsupported scala version: " + scalaVersion);
295+
}
286296
} else {
287-
throw new Exception("Unsupported scala version: " + scalaVersion);
297+
return detectSparkScalaVersionByReplClass(sparkHome);
298+
}
299+
} finally {
300+
if (!processOutputFile.delete() && processOutputFile.exists()) {
301+
LOGGER.warn("Failed to delete temporary file: {}", processOutputFile.getAbsolutePath());
288302
}
289-
} else {
290-
return detectSparkScalaVersionByReplClass(sparkHome);
291303
}
292304
}
293305

zeppelin-zengine/src/test/java/org/apache/zeppelin/interpreter/launcher/SparkInterpreterLauncherTest.java

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@
3636
import java.nio.file.Paths;
3737
import java.util.Arrays;
3838
import java.util.Properties;
39+
import java.util.Map;
40+
import java.util.HashMap;
41+
import java.lang.reflect.Method;
3942

4043
import static org.junit.jupiter.api.Assertions.assertEquals;
4144
import static org.junit.jupiter.api.Assertions.assertFalse;
@@ -325,4 +328,78 @@ void testYarnClusterMode_3() throws IOException {
325328
}
326329
FileUtils.deleteDirectory(localRepoPath.toFile());
327330
}
331+
332+
@Test
333+
void testDetectSparkScalaVersionResourceCleanup() throws Exception {
334+
SparkInterpreterLauncher launcher = new SparkInterpreterLauncher(zConf, null);
335+
336+
// Get temp directory before test
337+
File tempDir = new File(System.getProperty("java.io.tmpdir"));
338+
File[] filesBeforeTest = tempDir.listFiles((dir, name) -> name.startsWith("zeppelin-spark") && name.endsWith(".out"));
339+
int tempFilesCountBefore = filesBeforeTest != null ? filesBeforeTest.length : 0;
340+
341+
// Use reflection to access private method
342+
Method detectSparkScalaVersionMethod = SparkInterpreterLauncher.class.getDeclaredMethod(
343+
"detectSparkScalaVersion", String.class, Map.class);
344+
detectSparkScalaVersionMethod.setAccessible(true);
345+
346+
Map<String, String> env = new HashMap<>();
347+
348+
try {
349+
// Call the method
350+
String scalaVersion = (String) detectSparkScalaVersionMethod.invoke(launcher, sparkHome, env);
351+
352+
// Verify we got a valid result
353+
assertTrue(scalaVersion.equals("2.12") || scalaVersion.equals("2.13"),
354+
"Expected scala version 2.12 or 2.13 but got: " + scalaVersion);
355+
356+
// Check that no temp files were left behind
357+
File[] filesAfterTest = tempDir.listFiles((dir, name) -> name.startsWith("zeppelin-spark") && name.endsWith(".out"));
358+
int tempFilesCountAfter = filesAfterTest != null ? filesAfterTest.length : 0;
359+
360+
assertEquals(tempFilesCountBefore, tempFilesCountAfter,
361+
"Temporary files were not cleaned up properly");
362+
363+
} catch (Exception e) {
364+
// Even if the method fails, temp files should be cleaned up
365+
File[] filesAfterException = tempDir.listFiles((dir, name) -> name.startsWith("zeppelin-spark") && name.endsWith(".out"));
366+
int tempFilesCountAfterException = filesAfterException != null ? filesAfterException.length : 0;
367+
368+
assertEquals(tempFilesCountBefore, tempFilesCountAfterException,
369+
"Temporary files were not cleaned up after exception");
370+
371+
// Re-throw to fail the test if needed
372+
throw e;
373+
}
374+
}
375+
376+
@Test
377+
void testDetectSparkScalaVersionMultipleCalls() throws Exception {
378+
SparkInterpreterLauncher launcher = new SparkInterpreterLauncher(zConf, null);
379+
380+
// Get temp directory
381+
File tempDir = new File(System.getProperty("java.io.tmpdir"));
382+
File[] filesBeforeTest = tempDir.listFiles((dir, name) -> name.startsWith("zeppelin-spark") && name.endsWith(".out"));
383+
int tempFilesCountBefore = filesBeforeTest != null ? filesBeforeTest.length : 0;
384+
385+
// Use reflection to access private method
386+
Method detectSparkScalaVersionMethod = SparkInterpreterLauncher.class.getDeclaredMethod(
387+
"detectSparkScalaVersion", String.class, Map.class);
388+
detectSparkScalaVersionMethod.setAccessible(true);
389+
390+
Map<String, String> env = new HashMap<>();
391+
392+
// Call the method multiple times to ensure resources are properly cleaned each time
393+
for (int i = 0; i < 5; i++) {
394+
String scalaVersion = (String) detectSparkScalaVersionMethod.invoke(launcher, sparkHome, env);
395+
assertTrue(scalaVersion.equals("2.12") || scalaVersion.equals("2.13"));
396+
}
397+
398+
// Check that no temp files accumulated
399+
File[] filesAfterTest = tempDir.listFiles((dir, name) -> name.startsWith("zeppelin-spark") && name.endsWith(".out"));
400+
int tempFilesCountAfter = filesAfterTest != null ? filesAfterTest.length : 0;
401+
402+
assertEquals(tempFilesCountBefore, tempFilesCountAfter,
403+
"Temporary files accumulated after multiple calls");
404+
}
328405
}

0 commit comments

Comments
 (0)