Skip to content

Commit c71ec91

Browse files
committed
add test
1 parent 19b443b commit c71ec91

File tree

2 files changed

+157
-46
lines changed

2 files changed

+157
-46
lines changed

core/raydp-main/src/main/scala/org/apache/spark/executor/RayDPExecutor.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,7 @@ class RayDPExecutor(
383383
s"Forwarding fetch to executor $ownerSparkExecutorId " +
384384
s"(ray actor id $ownerRayExecutorId).")
385385
val otherHandle =
386-
Ray.getActor("raydp-executor-" + ownerRayExecutorId).get
386+
Ray.getActor("raydp-executor-" + ownerRayExecutorId).get()
387387
.asInstanceOf[ActorHandle[RayDPExecutor]]
388388
// One-hop forward only: call no-forward variant on the target executor and
389389
// return the Arrow IPC bytes directly.

python/raydp/tests/test_recoverable_forwarding.py

Lines changed: 156 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -15,64 +15,175 @@
1515
# limitations under the License.
1616
#
1717

18-
import platform
19-
18+
import os
2019
import pytest
20+
import pyarrow as pa
2121
from pyspark.storagelevel import StorageLevel
2222
import ray
23+
from ray.cluster_utils import Cluster
24+
from ray.data import from_arrow_refs
2325
import ray.util.client as ray_client
26+
import raydp
27+
28+
try:
29+
# Ray cross-language calls require enabling load_code_from_local.
30+
# This is an internal Ray API; keep it isolated and optional.
31+
from ray._private.worker import global_worker as _ray_global_worker # type: ignore
32+
except Exception: # pragma: no cover
33+
_ray_global_worker = None
34+
35+
@ray.remote(max_retries=-1)
def _fetch_arrow_table_from_executor(
    executor_actor_name: str,
    rdd_id: int,
    partition_id: int,
    schema_json: str,
    driver_agent_url: str,
) -> pa.Table:
    """Fetch a cached RDD partition from a JVM executor actor as a `pyarrow.Table`.

    Test-local copy of RayDP's recoverable fetch task. Keeping it inside the test
    sidesteps Ray remote-function registration issues when the driver and the
    workers import different `raydp` versions.
    """
    # Cross-language (Python -> Java) actor calls require load_code_from_local.
    if _ray_global_worker is not None:
        _ray_global_worker.set_load_code_from_local(True)

    handle = ray.get_actor(executor_actor_name)
    raw_ipc = ray.get(
        handle.getRDDPartition.remote(rdd_id, partition_id, schema_json, driver_agent_url)
    )
    decoded = pa.ipc.open_stream(pa.BufferReader(raw_ipc)).read_all()
    # Match RayDP behavior: strip schema metadata for stability.
    return decoded.replace_schema_metadata()
2763

28-
if platform.system() == "Darwin":
29-
# Spark-on-Ray recoverable path is unstable on macOS and can crash the raylet.
30-
pytest.skip("Skip recoverable forwarding test on macOS", allow_module_level=True)
3164

3265

33-
@pytest.mark.parametrize("spark_on_ray_2_executors", ["local"], indirect=True)
34-
def test_recoverable_forwarding_via_fetch_task(spark_on_ray_2_executors):
66+
def test_recoverable_forwarding_via_fetch_task(jdk17_extra_spark_configs):
    """Verify JVM-side forwarding in recoverable Spark->Ray conversion.

    This test intentionally calls the recoverable fetch task on the *wrong* Spark executor actor.
    It should still succeed because `RayDPExecutor.getRDDPartition` refreshes the block owner and
    forwards the fetch one hop.

    Environment knobs (both optional, used for crash bisection):
      - RAYDP_TRACE_STOP_AFTER: return right after the named phase completes.
      - RAYDP_FETCH_MODE: "driver" calls the Java actor directly from the driver;
        anything else (default "task") goes through the Ray remote task.
    """
    if ray_client.ray.is_connected():
        pytest.skip("Skip forwarding test in Ray client mode")

    stop_after = os.environ.get("RAYDP_TRACE_STOP_AFTER", "").strip().lower()
    fetch_mode = os.environ.get("RAYDP_FETCH_MODE", "task").strip().lower()
    cluster = Cluster(
        initialize_head=True,
        head_node_args={
            "num_cpus": 2,
            "resources": {"master": 10},
            "include_dashboard": True,
            "dashboard_port": 0,
        },
    )
    cluster.add_node(num_cpus=4, resources={"spark_executor": 10})

    def phase(name: str) -> None:
        # Prints are the most reliable breadcrumb if the raylet crashes.
        print(f"\n=== PHASE: {name} ===", flush=True)

    def should_stop(name: str) -> bool:
        return bool(stop_after) and stop_after == name.lower()

    spark = None
    try:
        # Single-node Ray is sufficient to reproduce / bisect the crash.
        phase("ray.init")
        ray.shutdown()
        ray.init(address=cluster.address, include_dashboard=False)
        if should_stop("ray.init"):
            return

        phase("raydp.init_spark")
        node_ip = ray.util.get_node_ip_address()
        spark = raydp.init_spark(
            app_name="test_recoverable_forwarding_via_fetch_task",
            num_executors=2,
            executor_cores=1,
            executor_memory="500M",
            configs={
                "spark.driver.host": node_ip,
                "spark.driver.bindAddress": node_ip,
                **jdk17_extra_spark_configs,
            },
        )
        if should_stop("raydp.init_spark"):
            return

        phase("spark.range.count")
        df = spark.range(0, 10000, numPartitions=8)
        _ = df.count()
        if should_stop("spark.range.count"):
            return

        phase("prepareRecoverableRDD")
        sc = spark.sparkContext
        storage_level = sc._getJavaStorageLevel(StorageLevel.MEMORY_AND_DISK)
        object_store_writer = sc._jvm.org.apache.spark.sql.raydp.ObjectStoreWriter
        info = object_store_writer.prepareRecoverableRDD(df._jdf, storage_level)
        rdd_id = info.rddId()
        schema_json = info.schemaJson()
        driver_agent_url = info.driverAgentUrl()
        locations = list(info.locations())
        if should_stop("preparerecoverablerdd"):
            return

        assert locations
        unique_execs = sorted(set(locations))
        assert len(unique_execs) >= 2, f"Need >=2 executors, got {unique_execs}"

        # Pick a partition and intentionally target the *wrong* executor actor.
        partition_id = 0
        owner_executor_id = locations[partition_id]
        wrong_executor_id = next(e for e in unique_execs if e != owner_executor_id)
        wrong_executor_actor_name = f"raydp-executor-{wrong_executor_id}"

        phase("fetch_wrong_executor")

        phase("get_wrong_executor_actor")
        wrong_executor_actor = ray.get_actor(wrong_executor_actor_name)
        if should_stop("get_wrong_executor_actor"):
            return

        phase("call_fetch_task")
        if fetch_mode == "driver":
            phase("driver_call_java_actor")
            # Cross-language call from the driver requires load_code_from_local.
            if _ray_global_worker is not None:
                _ray_global_worker.set_load_code_from_local(True)
            ipc_bytes = ray.get(
                wrong_executor_actor.getRDDPartition.remote(
                    rdd_id, partition_id, schema_json, driver_agent_url
                )
            )
            reader = pa.ipc.open_stream(pa.BufferReader(ipc_bytes))
            table = reader.read_all()
            table = table.replace_schema_metadata()
            # `table` is a pyarrow.Table here; it has no count() method, so
            # check num_rows (calling count() would raise AttributeError).
            assert table.num_rows > 0
        else:
            phase("task_call_java_actor")
            refs: list[ray.ObjectRef] = [
                _fetch_arrow_table_from_executor.remote(
                    wrong_executor_actor_name,
                    rdd_id,
                    partition_id,
                    schema_json,
                    driver_agent_url,
                )
            ]
            # from_arrow_refs yields a Ray Dataset, where count() is the row count.
            dataset = from_arrow_refs(refs)
            assert dataset.count() > 0
    finally:
        phase("teardown")

        # `spark` is still None when an early RAYDP_TRACE_STOP_AFTER return or a
        # failure happens before init_spark completes; guard so teardown does not
        # mask the real outcome with an AttributeError on spark.stop().
        if spark is not None:
            spark.stop()
        raydp.stop_spark()
        ray.shutdown()
        cluster.shutdown()
78189

0 commit comments

Comments
 (0)