@@ -25,6 +25,7 @@ import java.util.concurrent.atomic.AtomicBoolean
 import scala.reflect.classTag
 
 import com.intel.raydp.shims.SparkShimLoader
+import io.ray.api.ActorHandle
 import io.ray.api.Ray
 import io.ray.runtime.config.RayConfig
 import org.apache.arrow.vector.ipc.{ArrowStreamWriter, WriteChannel}
@@ -273,11 +274,15 @@ class RayDPExecutor(
     val blockIds = (0 until numPartitions).map(i =>
       BlockId.apply("rdd_" + rddId + "_" + i)
     ).toArray
-    val locations = BlockManager.blockIdsToLocations(blockIds, env)
-    var result = new Array[String](numPartitions)
-    for ((key, value) <- locations) {
-      val partitionId = key.name.substring(key.name.lastIndexOf('_') + 1).toInt
-      result(partitionId) = value(0).substring(value(0).lastIndexOf('_') + 1)
+    // Prefer structured locations (BlockManagerId.executorId) over parsing a string representation
+    // of ExecutorCacheTaskLocation. This is more robust across Spark versions.
+    val locsByBlock = env.blockManager.master.getLocations(blockIds)
+    val result = new Array[String](numPartitions)
+    for (i <- 0 until numPartitions) {
+      val locs = locsByBlock(i)
+      if (locs != null && locs.nonEmpty) {
+        result(i) = locs.head.executorId
+      }
     }
     result
   }
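Note (illustrative, not part of the patch): the hunk above replaces string parsing with the typed `BlockManagerId.executorId` field. A minimal sketch of the two approaches, assuming only Spark's `BlockManagerMaster.getLocations` API; `ownerOf` and `ownerOfFragile` are hypothetical names:

```scala
import org.apache.spark.SparkEnv
import org.apache.spark.storage.{BlockId, BlockManagerId}

// New style: ask the BlockManagerMaster and read the typed executorId field.
def ownerOf(blockId: BlockId): Option[String] = {
  val locs: Seq[BlockManagerId] = SparkEnv.get.blockManager.master.getLocations(blockId)
  locs.headOption.map(_.executorId)
}

// Old style: parse "executor_<host>_<executorId>" out of
// ExecutorCacheTaskLocation.toString -- brittle if that format ever changes.
def ownerOfFragile(loc: String): String =
  loc.substring(loc.lastIndexOf('_') + 1)
```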
@@ -305,11 +310,43 @@ class RayDPExecutor(
     env.shutdown
   }
 
-  def getRDDPartition(
+  /** Return the current executor ID that owns a cached Spark block, if any. */
+  private def getCurrentBlockOwnerExecutorId(blockId: BlockId): Option[String] = {
+    val env = SparkEnv.get
+    val locs = env.blockManager.master.getLocations(blockId)
+    if (locs != null && locs.nonEmpty) Some(locs.head.executorId) else None
+  }
+
+  /**
+   * Map a (potentially restarted) Spark executor ID to the Ray actor-name executor ID.
+   *
+   * When a RayDP executor actor restarts, it keeps its Ray actor name, but Spark may assign a new
+   * executor ID. RayAppMaster tracks a mapping (new -> old). We must use the old ID to resolve
+   * the Ray actor by name.
+   */
+  private def resolveRayActorExecutorId(sparkExecutorId: String): String = {
+    try {
+      val appMasterHandle =
+        Ray.getActor(RayAppMaster.ACTOR_NAME).get.asInstanceOf[ActorHandle[RayAppMaster]]
+      val restartedExecutors = RayAppMasterUtils.getRestartedExecutors(appMasterHandle)
+      if (restartedExecutors != null && restartedExecutors.containsKey(sparkExecutorId)) {
+        restartedExecutors.get(sparkExecutorId)
+      } else {
+        sparkExecutorId
+      }
+    } catch {
+      case _: Throwable =>
+        // Best-effort: if we cannot query the app master for any reason, fall back to the given ID.
+        sparkExecutorId
+    }
+  }
+
+  private def getRDDPartitionInternal(
       rddId: Int,
       partitionId: Int,
       schemaStr: String,
-      driverAgentUrl: String): Array[Byte] = {
+      driverAgentUrl: String,
+      allowForward: Boolean): Array[Byte] = {
     while (!started.get) {
       // wait until executor is started
       // this might happen if executor restarts
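Note (hypothetical values, illustration only): `resolveRayActorExecutorId` above exists because a restarted executor keeps its Ray actor name while Spark hands it a fresh executor ID. A tiny sketch of the `new -> old` mapping it consults; the IDs "5" and "7" are made up:

```scala
// If Spark executor "5" restarted and re-registered as "7", RayAppMaster records
// the mapping 7 -> 5; the Ray actor is still named after the original ID "5".
val restartedExecutors = new java.util.HashMap[String, String]()
restartedExecutors.put("7", "5")

val sparkExecutorId = "7"
val rayActorExecutorId =
  if (restartedExecutors.containsKey(sparkExecutorId)) restartedExecutors.get(sparkExecutorId)
  else sparkExecutorId
assert(rayActorExecutorId == "5")

// Name format used by the actor lookup later in this file.
val actorName = "raydp-executor-" + rayActorExecutorId
```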
@@ -330,11 +367,36 @@ class RayDPExecutor(
       case Some(blockResult) =>
         blockResult.data.asInstanceOf[Iterator[Array[Byte]]]
       case None =>
-        logWarning("The cached block has been lost. Cache it again via driver agent")
+        logWarning(s"The cached block $blockId has been lost. Cache it again via driver agent")
         requestRecacheRDD(rddId, driverAgentUrl)
         env.blockManager.get(blockId)(classTag[Array[Byte]]) match {
           case Some(blockResult) =>
             blockResult.data.asInstanceOf[Iterator[Array[Byte]]]
+          case None if allowForward =>
+            // The block may have been (re)cached on a different executor after recache.
+            val ownerOpt = getCurrentBlockOwnerExecutorId(blockId)
+            ownerOpt match {
+              case Some(ownerSparkExecutorId) =>
+                val ownerRayExecutorId = resolveRayActorExecutorId(ownerSparkExecutorId)
+                logWarning(
+                  s"Cached block $blockId not found on executor $executorId after recache. " +
+                  s"Forwarding fetch to executor $ownerSparkExecutorId " +
+                  s"(ray actor id $ownerRayExecutorId).")
+                val otherHandle =
+                  Ray.getActor("raydp-executor-" + ownerRayExecutorId).get()
+                    .asInstanceOf[ActorHandle[RayDPExecutor]]
+                // One-hop forward only: call no-forward variant on the target executor and
+                // return the Arrow IPC bytes directly.
+                return otherHandle
+                  .task(
+                    (e: RayDPExecutor) =>
+                      e.getRDDPartitionNoForward(rddId, partitionId, schemaStr, driverAgentUrl))
+                  .remote()
+                  .get()
+              case None =>
+                throw new RayDPException(
+                  s"Still cannot get block $blockId for RDD $rddId after recache!")
+            }
           case None =>
             throw new RayDPException("Still cannot get the block after recache!")
         }
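Note (stripped-down model, not the patch itself): the `allowForward` flag buys exactly one remote hop. The forwarded call runs with `allowForward = false`, so a second miss fails fast instead of bouncing between executors; `fetch` and `fetchOnOwner` below are hypothetical stand-ins for the real methods:

```scala
// One-hop forwarding contract: forward at most once, then fail.
def fetch(allowForward: Boolean, localHit: Option[Array[Byte]]): Array[Byte] =
  localHit.getOrElse {
    if (allowForward) fetchOnOwner() // single remote hop, no further forwarding
    else throw new IllegalStateException("block unavailable after recache")
  }

// On the owning executor the block is expected to be a local hit.
def fetchOnOwner(): Array[Byte] =
  fetch(allowForward = false, localHit = Some(Array.emptyByteArray))
```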
@@ -345,8 +407,26 @@ class RayDPExecutor(
     iterator.foreach(writeChannel.write)
     ArrowStreamWriter.writeEndOfStream(writeChannel, new IpcOption)
     val result = byteOut.toByteArray
-    writeChannel.close
-    byteOut.close
+    writeChannel.close()
+    byteOut.close()
     result
   }
 
+  /** Public entry-point used by cross-language calls. Allows forwarding. */
+  def getRDDPartition(
+      rddId: Int,
+      partitionId: Int,
+      schemaStr: String,
+      driverAgentUrl: String): Array[Byte] = {
+    getRDDPartitionInternal(rddId, partitionId, schemaStr, driverAgentUrl, allowForward = true)
+  }
+
+  /** Internal one-hop target to prevent forward loops. */
+  def getRDDPartitionNoForward(
+      rddId: Int,
+      partitionId: Int,
+      schemaStr: String,
+      driverAgentUrl: String): Array[Byte] = {
+    getRDDPartitionInternal(rddId, partitionId, schemaStr, driverAgentUrl, allowForward = false)
+  }
 }
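Note (sketch under stated assumptions): `getRDDPartition` returns one complete Arrow IPC stream as raw bytes; the Python side below opens it with `pa.ipc.open_stream`. Reading the same payload back on the JVM with Arrow Java's `ArrowStreamReader` would look roughly like this; `rowCount` is a hypothetical helper:

```scala
import java.io.ByteArrayInputStream
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.ipc.ArrowStreamReader

// Count rows across all record batches in the IPC payload.
def rowCount(ipcBytes: Array[Byte]): Long = {
  val allocator = new RootAllocator(Long.MaxValue)
  val reader = new ArrowStreamReader(new ByteArrayInputStream(ipcBytes), allocator)
  try {
    var rows = 0L
    while (reader.loadNextBatch()) {
      rows += reader.getVectorSchemaRoot.getRowCount
    }
    rows
  } finally {
    reader.close()
    allocator.close()
  }
}
```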
7 changes: 6 additions & 1 deletion python/raydp/spark/dataset.py
@@ -74,7 +74,12 @@ def _fetch_arrow_table_from_executor(executor_actor_name: str,
         executor_actor.getRDDPartition.remote(
             rdd_id, partition_id, schema_json, driver_agent_url))
     reader = pa.ipc.open_stream(pa.BufferReader(ipc_bytes))
-    return reader.read_all()
+    table = reader.read_all()
+    # Spark's Arrow conversion may attach schema metadata. Ray Data metadata extraction
+    # can be sensitive to unexpected schema metadata in some Ray/PyArrow combinations.
+    # Strip schema metadata to make blocks more portable/deterministic.
+    table = table.replace_schema_metadata()
+    return table
 
 
 class RecordPiece: