Skip to content

Commit 8e16e5d

Browse files
committed
more robust executor id parse
1 parent e073c14 commit 8e16e5d

File tree

1 file changed

+13
-18
lines changed

1 file changed

+13
-18
lines changed

core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -274,11 +274,15 @@ class RayDPExecutor(
274274
val blockIds = (0 until numPartitions).map(i =>
275275
BlockId.apply("rdd_" + rddId + "_" + i)
276276
).toArray
277-
val locations = BlockManager.blockIdsToLocations(blockIds, env)
278-
var result = new Array[String](numPartitions)
279-
for ((key, value) <- locations) {
280-
val partitionId = key.name.substring(key.name.lastIndexOf('_') + 1).toInt
281-
result(partitionId) = value(0).substring(value(0).lastIndexOf('_') + 1)
277+
// Prefer structured locations (BlockManagerId.executorId) over parsing a string representation
278+
// of ExecutorCacheTaskLocation. This is more robust across Spark versions.
279+
val locsByBlock = env.blockManager.master.getLocations(blockIds)
280+
val result = new Array[String](numPartitions)
281+
for (i <- 0 until numPartitions) {
282+
val locs = locsByBlock(i)
283+
if (locs != null && locs.nonEmpty) {
284+
result(i) = locs.head.executorId
285+
}
282286
}
283287
result
284288
}
@@ -306,20 +310,11 @@ class RayDPExecutor(
306310
env.shutdown
307311
}
308312

309-
private def parseExecutorIdFromLocation(loc: String): String = {
310-
loc.substring(loc.lastIndexOf('_') + 1)
311-
}
312-
313313
/** Refresh the current executor ID that owns a cached Spark block, if any. */
314314
private def getCurrentBlockOwnerExecutorId(blockId: BlockId): Option[String] = {
315315
val env = SparkEnv.get
316-
val locations = BlockManager.blockIdsToLocations(Array(blockId), env)
317-
val locs = locations.getOrElse(blockId, Seq.empty[String])
318-
if (locs.nonEmpty) {
319-
Some(parseExecutorIdFromLocation(locs.head))
320-
} else {
321-
None
322-
}
316+
val locs = env.blockManager.master.getLocations(blockId)
317+
if (locs != null && locs.nonEmpty) Some(locs.head.executorId) else None
323318
}
324319

325320
/**
@@ -411,8 +406,8 @@ class RayDPExecutor(
411406
iterator.foreach(writeChannel.write)
412407
ArrowStreamWriter.writeEndOfStream(writeChannel, new IpcOption)
413408
val result = byteOut.toByteArray
414-
writeChannel.close
415-
byteOut.close
409+
writeChannel.close()
410+
byteOut.close()
416411
result
417412
}
418413

0 commit comments

Comments
 (0)