Commit b6510ac

use mixin base class

1 parent 30de99c commit b6510ac

4 files changed: +54 −35 lines changed

core/raydp-main/src/main/scala/org/apache/spark/sql/raydp/ObjectStoreWriter.scala

Lines changed: 32 additions & 8 deletions
@@ -19,6 +19,9 @@ package org.apache.spark.sql.raydp
 
 import com.intel.raydp.shims.SparkShimLoader
 import io.ray.api.{ActorHandle, ObjectRef, Ray}
+import io.ray.api.PyActorHandle
+import io.ray.api.call.PyActorTaskCaller
+import io.ray.api.function.PyActorMethod
 import io.ray.runtime.AbstractRayRuntime
 import java.io.ByteArrayOutputStream
 import java.util.{List, Optional, UUID}
@@ -65,12 +68,33 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable {
       ownerName: String): RecordBatch = {
 
     // NOTE: We intentionally do NOT pass an owner argument to Ray.put anymore.
-    // The default JVM path puts the serialized Arrow batch into Ray's object store
-    // from the Spark executor JVM process.
     //
-    // Ownership transfer to a long-lived Python actor is implemented on the Python side
-    // by "adopting" (re-putting) these ObjectRefs inside the target actor.
-    val objectRef: ObjectRef[Array[Byte]] = Ray.put(data)
+    // - When ownerName is empty, put the batch from the executor JVM (plain Ray.put).
+    // - When ownerName is set to a Python actor name (e.g. RayDPSparkMaster),
+    //   invoke that Python actor's put_data(data) method via Ray cross-language
+    //   calls so that the Python actor becomes the owner of the created object.
+    val objectRef: ObjectRef[_] =
+      if (ownerName == "") {
+        Ray.put(data)
+      } else {
+        // Ray.getActor(String) returns a raw java.util.Optional in Ray's Java API.
+        // If we don't cast it to an explicit reference type here, Scala may infer
+        // Optional[Nothing] and insert an invalid cast at runtime.
+        val opt = Ray.getActor(ownerName).asInstanceOf[Optional[AnyRef]]
+        if (!opt.isPresent) {
+          throw new RayDPException(s"Actor $ownerName not found when putting dataset block.")
+        }
+        val handleAny: AnyRef = opt.get()
+        if (!handleAny.isInstanceOf[PyActorHandle]) {
+          throw new RayDPException(
+            s"Actor $ownerName is not a Python actor; cannot invoke put_data."
+          )
+        }
+        val pyHandle = handleAny.asInstanceOf[PyActorHandle]
+        val method = PyActorMethod.of("put_data", classOf[AnyRef])
+        val refOfRef = pyHandle.task(method, data).remote()
+        refOfRef
+      }
 
     // add the objectRef to the objectRefHolder to avoid reference GC
     queue.add(objectRef)
@@ -171,7 +195,7 @@ class ObjectStoreWriter(@transient val df: DataFrame) extends Serializable {
   /**
    * For test.
    */
-  def getRandomRef(): List[Array[Byte]] = {
+  def getRandomRef(): List[_] = {
 
     df.queryExecution.toRdd.mapPartitions { _ =>
       Iterator(ObjectRefHolder.getRandom(uuid))
@@ -270,7 +294,7 @@ object ObjectStoreWriter {
 }
 
 object ObjectRefHolder {
-  type Queue = ConcurrentLinkedQueue[ObjectRef[Array[Byte]]]
+  type Queue = ConcurrentLinkedQueue[ObjectRef[_]]
   private val dfToQueue = new ConcurrentHashMap[UUID, Queue]()
 
   def getQueue(df: UUID): Queue = {
@@ -295,7 +319,7 @@ object ObjectRefHolder {
     queue.size()
   }
 
-  def getRandom(df: UUID): Array[Byte] = {
+  def getRandom(df: UUID): Any = {
    val queue = checkQueueExists(df)
    val ref = RayDPUtils.convert(queue.peek())
    ref.get()
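For readers unfamiliar with Ray's cross-language support: `pyHandle.task(method, data).remote()` invokes a named Python actor's method from the JVM, and the resulting ObjectRef points at whatever that method returns, owned by the Python actor's process. A minimal sketch of the Python-side contract the Scala code above assumes (the actor name `block_owner` is illustrative, not RayDP's):

import ray
import pyarrow as pa

ray.init()

@ray.remote
class BlockOwner:
    # Name and arity must match the JVM call site:
    # PyActorMethod.of("put_data", classOf[AnyRef]).
    def put_data(self, data: bytes) -> "pa.Table":
        # data is an Arrow IPC stream produced by ArrowStreamWriter on the JVM;
        # the table returned here lands in the object store owned by this actor.
        return pa.ipc.open_stream(pa.BufferReader(data)).read_all()

# The JVM looks the actor up with Ray.getActor(ownerName), so it must be
# created as a *named* actor.
owner = BlockOwner.options(name="block_owner").remote()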

python/raydp/spark/dataset.py

Lines changed: 2 additions & 7 deletions
@@ -103,9 +103,7 @@ def get_raydp_master_owner(spark: Optional[SparkSession] = None) -> PartitionObj
     def raydp_master_set_reference_as_state(
             raydp_master_actor: ray.actor.ActorHandle,
             objects: List[ObjectRef]) -> ObjectRef:
-        # Adopt objects in the Python master actor so it becomes the owner of the
-        # dataset blocks without using Ray.put `_owner`.
-        return raydp_master_actor.adopt_objects.remote(uuid.uuid4(), objects)
+        return raydp_master_actor.add_objects.remote(uuid.uuid4(), objects)
 
     return PartitionObjectsOwner(
         obj_holder_name,
@@ -143,10 +141,7 @@ def _save_spark_df_to_object_store(df: sql.DataFrame, use_batch: bool = True,
 
     if owner is not None:
         actor_owner = ray.get_actor(actor_owner_name)
-        adopted = ray.get(owner.set_reference_as_state(actor_owner, blocks))
-        # If the owner callback returns a new list of refs (adoption), use it.
-        if adopted is not None:
-            blocks = adopted
+        ray.get(owner.set_reference_as_state(actor_owner, blocks))
 
     return blocks, block_sizes
 
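Since ownership is now established at put time on the JVM side, `set_reference_as_state` only pins the refs in the master actor's state via `add_objects`. A hedged end-to-end sketch of the master-owner path (cluster sizing values are illustrative; import locations follow this repo's layout):

import ray
import raydp
from raydp.spark import get_raydp_master_owner
from raydp.spark.dataset import spark_dataframe_to_ray_dataset

ray.init()
spark = raydp.init_spark(app_name="owner_demo", num_executors=1,
                         executor_cores=1, executor_memory="1G")

df = spark.range(0, 1000)
# RayDPSparkMaster ends up owning the dataset blocks, so the Ray dataset
# survives even after the Spark executors that produced it exit.
ds = spark_dataframe_to_ray_dataset(df, parallelism=2,
                                    owner=get_raydp_master_owner(spark))
ds.show(5)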

python/raydp/spark/ray_cluster_master.py

Lines changed: 17 additions & 13 deletions
@@ -29,6 +29,7 @@
 import ray
 from py4j.java_gateway import JavaGateway, GatewayParameters
 from raydp import versions
+import pyarrow as pa
 
 logger = logging.getLogger(__name__)
 
@@ -48,10 +49,25 @@
 SPARK_LOG4J_CONFIG_FILE_NAME = "spark.log4j.config.file.name"
 RAY_LOG4J_CONFIG_FILE_NAME = "spark.ray.log4j.config.file.name"
 
+class RayDPObjectOwnerMixin:
+    """Mixin for Python actors that can be used as dataset block owners.
+
+    The JVM side can invoke the actor method `put_data` via Ray's cross-language
+    actor call support so that this Python actor becomes the owner of the created
+    objects, without using Ray's experimental `ray.put(_owner=...)` API.
+    """
+
+    def put_data(self, data) -> "pa.Table":
+        """Put one serialized Arrow batch into the Ray object store."""
+        # data is Arrow IPC stream bytes written by ArrowStreamWriter
+        reader = pa.ipc.open_stream(pa.BufferReader(data))
+        table = reader.read_all()
+        return table
+
 
 
 @ray.remote
-class RayDPSparkMaster():
+class RayDPSparkMaster(RayDPObjectOwnerMixin):
     def __init__(self, configs):
         self._gateway = None
         self._app_master_java_bridge = None
@@ -224,18 +240,6 @@ def get_spark_home(self) -> str:
     def add_objects(self, timestamp, objects):
         self._objects[timestamp] = objects
 
-    def adopt_objects(self, timestamp, objects):
-        """Adopt objects by re-putting them inside this actor.
-
-        This makes this actor the owner of the newly created objects without
-        using the Ray.put `_owner` argument.
-
-        Returns the new ObjectRefs.
-        """
-        new_objects = [ray.put(ray.get(obj)) for obj in objects]
-        self._objects[timestamp] = new_objects
-        return new_objects
-
     def get_object(self, timestamp, idx):
         return self._objects[timestamp][idx]
 
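Because the owner behavior now lives in a mixin, any user-defined actor can opt in by inheriting it, which is exactly what the test below does. A standalone sketch of the `put_data` round trip driven from Python (actor name and sample table are made up):

import ray
import pyarrow as pa
from raydp.spark.ray_cluster_master import RayDPObjectOwnerMixin

@ray.remote
class CustomOwner(RayDPObjectOwnerMixin):
    def __init__(self):
        self.blocks = []

    def set_objects(self, objects):
        self.blocks = objects

ray.init()
owner = CustomOwner.options(name="custom_owner").remote()

# Build the same kind of payload ArrowStreamWriter produces on the JVM side.
table = pa.table({"x": [1, 2, 3]})
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, table.schema) as writer:
    writer.write_table(table)

# put_data returns an ObjectRef owned by the actor; the JVM reaches the same
# method cross-language via PyActorMethod.of("put_data", ...).
ref = owner.put_data.remote(sink.getvalue().to_pybytes())
assert ray.get(ref).num_rows == 3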

python/raydp/tests/test_data_owner_transfer.py

Lines changed: 3 additions & 7 deletions
@@ -12,6 +12,7 @@
 from raydp.spark import PartitionObjectsOwner
 from pyspark.sql import SparkSession
 from raydp.spark import get_raydp_master_owner
+from raydp.spark.ray_cluster_master import RayDPObjectOwnerMixin
 
 
 def gen_test_data(spark_session: SparkSession):
@@ -145,7 +146,7 @@ def test_custom_ownership_transfer_custom_actor(ray_cluster, jdk17_extra_spark_c
     """
 
     @ray.remote
-    class CustomActor:
+    class CustomActor(RayDPObjectOwnerMixin):
        objects: Any
 
        def wake(self):
@@ -154,11 +155,6 @@ def wake(self):
        def set_objects(self, objects):
            self.objects = objects
 
-       def adopt_objects(self, objects):
-           # Re-put inside this actor so this actor becomes the owner of the new objects.
-           self.objects = [ray.put(ray.get(o)) for o in objects]
-           return self.objects
-
    if ray_client.ray.is_connected():
        pytest.skip("Skip this test if using ray client")
 
@@ -190,7 +186,7 @@ def adopt_objects(self, objects):
    # and transfer data ownership to dedicated Object Holder (Singleton)
    ds = spark_dataframe_to_ray_dataset(df_train, parallelism=4, owner=PartitionObjectsOwner(
        owner_actor_name,
-       lambda actor, objects: actor.adopt_objects.remote(objects)))
+       lambda actor, objects: actor.set_objects.remote(objects)))
 
    # display data
    ds.show(5)
