
Commit 5a98b1e

remove usage of ray.put(data, _owner) and private ray object ownership manipulation API (#454)

* remove usage of ray.put(data, owner)
* use mixin base class
* fix test
* eliminate usage of deserialize_and_register_object_ref
* distributed owner actors
* fix test
* reimplement recoverable conversion
* make data fetch task resource configurable
* only support ray 2.37.0 and beyond
* fix pyspark internal cache
* add test against ray 2.50.0
* remove usage of dashboard_grpc_port
* fix read_parquet on ray 2.5x, make tf test work on Apple silicon
* implement single owner
* rename RayDPBlockStoreActorRegistry to RayDPDataOwner
* add test to gate data locality
* align everything use from_spark_recoverable

1 parent 3ef2bac commit 5a98b1e

File tree: 22 files changed, +436 -526 lines

.github/workflows/raydp.yml

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ jobs:
         os: [ubuntu-latest]
         python-version: [3.9, 3.10.14]
         spark-version: [3.3.2, 3.4.0, 3.5.0]
-        ray-version: [2.34.0, 2.40.0]
+        ray-version: [2.37.0, 2.40.0, 2.50.0]
 
     runs-on: ${{ matrix.os }}

README.md

Lines changed: 8 additions & 11 deletions
@@ -153,28 +153,25 @@ Please refer to [NYC Taxi PyTorch Estimator](./examples/pytorch_nyctaxi.py) and
 
 ***Fault Tolerance***
 
-The ray dataset converted from spark dataframe like above is not fault-tolerant. This is because we implement it using `Ray.put` combined with spark `mapPartitions`. Objects created by `Ray.put` is not recoverable in Ray.
+RayDP now converts Spark DataFrames to Ray Datasets using a recoverable pipeline by default. This makes the resulting Ray Dataset resilient to Spark executor loss (the Arrow IPC bytes are cached in Spark and fetched via Ray tasks with lineage).
+
+The recoverable conversion is also available directly via `raydp.spark.from_spark_recoverable`, and it persists (caches) the Spark DataFrame. You can provide the storage level through the `storage_level` keyword parameter.
 
-RayDP now supports converting data in a way such that the resulting ray dataset is fault-tolerant. This feature is currently *experimental*. Here is how to use it:
 ```python
 import ray
 import raydp
 
 ray.init(address="auto")
-# set fault_tolerance_mode to True to enable the feature
-# this will connect pyspark driver to ray cluster
 spark = raydp.init_spark(app_name="RayDP Example",
                          num_executors=2,
                          executor_cores=2,
-                         executor_memory="4GB",
-                         fault_tolerance_mode=True)
-# df should be large enough so that result will be put into plasma
+                         executor_memory="4GB")
+
 df = spark.range(100000)
-# use this API instead of ray.data.from_spark
-ds = raydp.spark.from_spark_recoverable(df)
-# ds is now fault-tolerant.
+ds = raydp.spark.from_spark_recoverable(df) # fault-tolerant
 ```
-Notice that `from_spark_recoverable` will persist the converted dataframe. You can provide the storage level through keyword parameter `storage_level`. In addition, this feature is not available in ray client mode. If you need to use ray client, please wrap your application in a ray actor, as described in the ray client chapter.
+
+Note: recoverable conversion is not available in Ray client mode. If you need to use Ray client, wrap your application in a Ray actor as described in the Ray client docs.
 
 
 ## Getting Involved
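
To illustrate the `storage_level` keyword mentioned in the updated README, here is a minimal sketch; it assumes `from_spark_recoverable` accepts a PySpark `StorageLevel` value, and the `MEMORY_AND_DISK` choice is only an example, not a recommendation:

```python
import ray
import raydp
from pyspark import StorageLevel

ray.init(address="auto")
spark = raydp.init_spark(app_name="RayDP Example",
                         num_executors=2,
                         executor_cores=2,
                         executor_memory="4GB")

df = spark.range(100000)
# Persist the cached conversion data to memory and disk so partitions can be
# re-fetched after executor memory pressure; pick the level that fits your job.
ds = raydp.spark.from_spark_recoverable(df, storage_level=StorageLevel.MEMORY_AND_DISK)
```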

core/raydp-main/src/main/java/org/apache/spark/raydp/RayExecutorUtils.java

Lines changed: 14 additions & 23 deletions
@@ -21,18 +21,14 @@
 import io.ray.api.ObjectRef;
 import io.ray.api.Ray;
 import io.ray.api.call.ActorCreator;
-import java.util.Map;
-import java.util.List;
-
 import io.ray.api.placementgroup.PlacementGroup;
 import io.ray.runtime.object.ObjectRefImpl;
+import java.util.List;
+import java.util.Map;
 import org.apache.spark.executor.RayDPExecutor;
 
 public class RayExecutorUtils {
-  /**
-   * Convert from mbs -> memory units. The memory units in ray is byte
-   */
-
+  /** Convert from mbs -> memory units. The memory units in ray is byte. */
   private static double toMemoryUnits(int memoryInMB) {
     double result = 1.0 * memoryInMB * 1024 * 1024;
     return Math.round(result);
@@ -47,14 +43,13 @@ public static ActorHandle<RayDPExecutor> createExecutorActor(
       PlacementGroup placementGroup,
       int bundleIndex,
       List<String> javaOpts) {
-    ActorCreator<RayDPExecutor> creator = Ray.actor(
-        RayDPExecutor::new, executorId, appMasterURL);
+    ActorCreator<RayDPExecutor> creator = Ray.actor(RayDPExecutor::new, executorId, appMasterURL);
     creator.setName("raydp-executor-" + executorId);
     creator.setJvmOptions(javaOpts);
     creator.setResource("CPU", cores);
     creator.setResource("memory", toMemoryUnits(memoryInMB));
 
-    for (Map.Entry<String, Double> entry: resources.entrySet()) {
+    for (Map.Entry<String, Double> entry : resources.entrySet()) {
       creator.setResource(entry.getKey(), entry.getValue());
     }
     if (placementGroup != null) {
@@ -72,16 +67,12 @@ public static void setUpExecutor(
       String driverUrl,
       int cores,
       String classPathEntries) {
-    handler.task(RayDPExecutor::startUp,
-        appId, driverUrl, cores, classPathEntries).remote();
+    handler.task(RayDPExecutor::startUp, appId, driverUrl, cores, classPathEntries).remote();
   }
 
   public static String[] getBlockLocations(
-      ActorHandle<RayDPExecutor> handler,
-      int rddId,
-      int numPartitions) {
-    return handler.task(RayDPExecutor::getBlockLocations,
-        rddId, numPartitions).remote().get();
+      ActorHandle<RayDPExecutor> handler, int rddId, int numPartitions) {
+    return handler.task(RayDPExecutor::getBlockLocations, rddId, numPartitions).remote().get();
   }
 
   public static ObjectRef<byte[]> getRDDPartition(
@@ -90,14 +81,14 @@ public static ObjectRef<byte[]> getRDDPartition(
       int partitionId,
       String schema,
       String driverAgentUrl) {
-    return (ObjectRefImpl<byte[]>) handle.task(
-        RayDPExecutor::getRDDPartition,
-        rddId, partitionId, schema, driverAgentUrl).remote();
+    return (ObjectRefImpl<byte[]>)
+        handle.task(RayDPExecutor::getRDDPartition, rddId, partitionId, schema, driverAgentUrl)
+            .remote();
   }
 
-  public static void exitExecutor(
-      ActorHandle<RayDPExecutor> handle
-  ) {
+  public static void exitExecutor(ActorHandle<RayDPExecutor> handle) {
     handle.task(RayDPExecutor::stop).remote();
   }
 }
+
+

core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala

Lines changed: 73 additions & 130 deletions
@@ -18,156 +18,28 @@
 package org.apache.spark.sql.raydp
 
 import com.intel.raydp.shims.SparkShimLoader
-import io.ray.api.{ActorHandle, ObjectRef, PyActorHandle, Ray}
+import io.ray.api.{ActorHandle, ObjectRef, Ray}
 import io.ray.runtime.AbstractRayRuntime
-import java.io.ByteArrayOutputStream
 import java.util.{List, UUID}
 import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue}
 import java.util.function.{Function => JFunction}
-import org.apache.arrow.vector.VectorSchemaRoot
-import org.apache.arrow.vector.ipc.ArrowStreamWriter
 import org.apache.arrow.vector.types.pojo.Schema
 import scala.collection.JavaConverters._
 import scala.collection.mutable
-import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.{RayDPException, SparkContext}
 import org.apache.spark.deploy.raydp._
 import org.apache.spark.executor.RayDPExecutor
+import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.raydp.{RayDPUtils, RayExecutorUtils}
 import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.execution.arrow.ArrowWriter
-import org.apache.spark.sql.execution.python.BatchIterator
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.util.ArrowUtils
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.Utils
-
-/**
- * A batch of record that has been wrote into Ray object store.
- * @param ownerAddress the owner address of the ray worker
- * @param objectId the ObjectId for the stored data
- * @param numRecords the number of records for the stored data
- */
-case class RecordBatch(
-    ownerAddress: Array[Byte],
-    objectId: Array[Byte],
-    numRecords: Int)
 
 class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable {
 
   val uuid: UUID = ObjectStoreWriter.dfToId.getOrElseUpdate(df, UUID.randomUUID())
 
-  def writeToRay(
-      data: Array[Byte],
-      numRecords: Int,
-      queue: ObjectRefHolder.Queue,
-      ownerName: String): RecordBatch = {
-
-    var objectRef: ObjectRef[Array[Byte]] = null
-    if (ownerName == "") {
-      objectRef = Ray.put(data)
-    } else {
-      var dataOwner: PyActorHandle = Ray.getActor(ownerName).get()
-      objectRef = Ray.put(data, dataOwner)
-    }
-
-    // add the objectRef to the objectRefHolder to avoid reference GC
-    queue.add(objectRef)
-    val objectRefImpl = RayDPUtils.convert(objectRef)
-    val objectId = objectRefImpl.getId
-    val runtime = Ray.internal.asInstanceOf[AbstractRayRuntime]
-    val addressInfo = runtime.getObjectStore.getOwnershipInfo(objectId)
-    RecordBatch(addressInfo, objectId.getBytes, numRecords)
-  }
-
-  /**
-   * Save the DataFrame to Ray object store with Apache Arrow format.
-   */
-  def save(useBatch: Boolean, ownerName: String): List[RecordBatch] = {
-    val conf = df.queryExecution.sparkSession.sessionState.conf
-    val timeZoneId = conf.getConf(SQLConf.SESSION_LOCAL_TIMEZONE)
-    var batchSize = conf.getConf(SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH)
-    if (!useBatch) {
-      batchSize = 0
-    }
-    val schema = df.schema
-
-    val objectIds = df.queryExecution.toRdd.mapPartitions{ iter =>
-      val queue = ObjectRefHolder.getQueue(uuid)
-
-      // DO NOT use iter.grouped(). See BatchIterator.
-      val batchIter = if (batchSize > 0) {
-        new BatchIterator(iter, batchSize)
-      } else {
-        Iterator(iter)
-      }
-
-      val arrowSchema = SparkShimLoader.getSparkShims.toArrowSchema(schema, timeZoneId)
-      val allocator = ArrowUtils.rootAllocator.newChildAllocator(
-        s"ray object store writer", 0, Long.MaxValue)
-      val root = VectorSchemaRoot.create(arrowSchema, allocator)
-      val results = new ArrayBuffer[RecordBatch]()
-
-      val byteOut = new ByteArrayOutputStream()
-      val arrowWriter = ArrowWriter.create(root)
-      var numRecords: Int = 0
-
-      Utils.tryWithSafeFinally {
-        while (batchIter.hasNext) {
-          // reset the state
-          numRecords = 0
-          byteOut.reset()
-          arrowWriter.reset()
-
-          // write out the schema meta data
-          val writer = new ArrowStreamWriter(root, null, byteOut)
-          writer.start()
-
-          // get the next record batch
-          val nextBatch = batchIter.next()
-
-          while (nextBatch.hasNext) {
-            numRecords += 1
-            arrowWriter.write(nextBatch.next())
-          }
-
-          // set the write record count
-          arrowWriter.finish()
-          // write out the record batch to the underlying out
-          writer.writeBatch()
-
-          // get the wrote ByteArray and save to Ray ObjectStore
-          val byteArray = byteOut.toByteArray
-          results += writeToRay(byteArray, numRecords, queue, ownerName)
-          // end writes footer to the output stream and doesn't clean any resources.
-          // It could throw exception if the output stream is closed, so it should be
-          // in the try block.
-          writer.end()
-        }
-        arrowWriter.reset()
-        byteOut.close()
-      } {
-        // If we close root and allocator in TaskCompletionListener, there could be a race
-        // condition where the writer thread keeps writing to the VectorSchemaRoot while
-        // it's being closed by the TaskCompletion listener.
-        // Closing root and allocator here is cleaner because root and allocator is owned
-        // by the writer thread and is only visible to the writer thread.
-        //
-        // If the writer thread is interrupted by TaskCompletionListener, it should either
-        // (1) in the try block, in which case it will get an InterruptedException when
-        // performing io, and goes into the finally block or (2) in the finally block,
-        // in which case it will ignore the interruption and close the resources.
-
-        root.close()
-        allocator.close()
-      }
-
-      results.toIterator
-    }.collect()
-    objectIds.toSeq.asJava
-  }
-
   /**
    * For test.
    */
@@ -201,6 +73,15 @@ object ObjectStoreWriter {
     }
   }
 
+  private def parseMemoryBytes(value: String): Double = {
+    if (value == null || value.isEmpty) {
+      0.0
+    } else {
+      // Spark parser supports both plain numbers (bytes) and strings like "100M", "2g".
+      JavaUtils.byteStringAsBytes(value).toDouble
+    }
+  }
+
   def getAddress(): Array[Byte] = {
     if (address == null) {
       val objectRef = Ray.put(1)
@@ -218,6 +99,7 @@ object ObjectStoreWriter {
     SparkShimLoader.getSparkShims.toArrowSchema(df.schema, timeZoneId)
   }
 
+  @deprecated
   def fromSparkRDD(df: DataFrame, storageLevel: StorageLevel): Array[Array[Byte]] = {
     if (!Ray.isInitialized) {
       throw new RayDPException(
@@ -267,6 +149,67 @@ object ObjectStoreWriter {
     results
   }
 
+  /**
+   * Prepare a Spark ArrowBatch RDD for recoverable conversion and return metadata needed by
+   * Python to build reconstructable Ray Dataset blocks via Ray tasks.
+   *
+   * This method:
+   * - persists and materializes the ArrowBatch RDD in Spark (so partitions can be re-fetched)
+   * - computes per-partition executor locations (Spark executor IDs)
+   *
+   * It does NOT push any data to Ray.
+   */
+  def prepareRecoverableRDD(
+      df: DataFrame,
+      storageLevel: StorageLevel): RecoverableRDDInfo = {
+    if (!Ray.isInitialized) {
+      throw new RayDPException(
+        "Not yet connected to Ray! Please set fault_tolerant_mode=True when starting RayDP.")
+    }
+
+    val rdd = df.toArrowBatchRdd
+    rdd.persist(storageLevel)
+    rdd.count()
+
+    var executorIds = df.sqlContext.sparkContext.getExecutorIds.toArray
+    val numExecutors = executorIds.length
+    val appMasterHandle = Ray.getActor(RayAppMaster.ACTOR_NAME)
+      .get.asInstanceOf[ActorHandle[RayAppMaster]]
+    val restartedExecutors = RayAppMasterUtils.getRestartedExecutors(appMasterHandle)
+    if (!restartedExecutors.isEmpty) {
+      for (i <- 0 until numExecutors) {
+        if (restartedExecutors.containsKey(executorIds(i))) {
+          val oldId = restartedExecutors.get(executorIds(i))
+          executorIds(i) = oldId
+        }
+      }
+    }
+
+    val schemaJson = ObjectStoreWriter.toArrowSchema(df).toJson
+    val numPartitions = rdd.getNumPartitions
+
+    val handles = executorIds.map { id =>
+      Ray.getActor("raydp-executor-" + id)
+        .get
+        .asInstanceOf[ActorHandle[RayDPExecutor]]
+    }
+    val locations = RayExecutorUtils.getBlockLocations(handles(0), rdd.id, numPartitions)
+
+    RecoverableRDDInfo(rdd.id, numPartitions, schemaJson, driverAgentUrl, locations)
+  }
+
+}
+
+case class RecoverableRDDInfo(
+    rddId: Int,
+    numPartitions: Int,
+    schemaJson: String,
+    driverAgentUrl: String,
+    locations: Array[String])
+
+object RecoverableRDDInfo {
+  // Empty constructor for reflection / Java interop (some tools expect it).
+  def empty: RecoverableRDDInfo = RecoverableRDDInfo(0, 0, "", "", Array.empty[String])
 }
 
 object ObjectRefHolder {
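
For readers connecting the Scala changes above to the Python side, here is a minimal, hypothetical sketch of why the task-based conversion is recoverable: blocks come from Ray tasks (re-executable via lineage) that read Arrow IPC bytes cached in Spark, rather than from `Ray.put` with a manipulated owner. The helpers `fetch_partition_ipc_bytes`, `build_block`, and `dataset_from_recoverable_rdd` are illustrative names, not RayDP's actual internals.

```python
import io

import pyarrow as pa
import ray
import ray.data


def fetch_partition_ipc_bytes(rdd_id: int, partition_id: int) -> bytes:
    # Placeholder for the real RPC that asks a raydp-executor actor for the
    # Arrow IPC bytes of one cached ArrowBatch RDD partition
    # (see RayExecutorUtils.getRDDPartition above).
    raise NotImplementedError


@ray.remote(num_cpus=0)
def build_block(rdd_id: int, partition_id: int) -> pa.Table:
    # Because this runs as a Ray task, Ray can re-execute it through lineage
    # reconstruction if the block is lost, as long as Spark still caches the
    # persisted ArrowBatch RDD.
    data = fetch_partition_ipc_bytes(rdd_id, partition_id)
    with pa.ipc.open_stream(io.BytesIO(data)) as reader:
        return reader.read_all()


def dataset_from_recoverable_rdd(rdd_id: int, num_partitions: int) -> ray.data.Dataset:
    # Build one block per Spark partition and hand the object refs to Ray Data.
    refs = [build_block.remote(rdd_id, p) for p in range(num_partitions)]
    return ray.data.from_arrow_refs(refs)
```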
