Skip to content

Commit 03ed810

Browse files
fyrestone刘宝
and
刘宝
authored
[Ray] Ray execution state (#3002)
* Ray execution state * Fix stop pool * Not to fetch chunk meta when tiling HeadOptimizedDataSource * Fix * Fix * Use named actor for Ray task state * Improve coverage * Fix lint Co-authored-by: 刘宝 <[email protected]>
1 parent c2e334f commit 03ed810

File tree

11 files changed

+254
-71
lines changed

11 files changed

+254
-71
lines changed

mars/dataframe/datasource/core.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,7 @@ def _tile_head(cls, op: "HeadOptimizedDataSource"):
5858
# execute first chunk
5959
yield chunks[:1]
6060

61-
ctx = get_context()
62-
chunk_shape = ctx.get_chunks_meta([chunks[0].key], fields=["shape"])[0]["shape"]
63-
61+
chunk_shape = chunks[0].shape
6462
if chunk_shape[0] == op.nrows:
6563
# the first chunk has enough data
6664
tileds[0]._nsplits = tuple((s,) for s in chunk_shape)

mars/deploy/oscar/tests/test_local.py

+25-19
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,7 @@ async def test_web_session(create_cluster, config):
490490
)
491491

492492

493-
@pytest.mark.parametrize("config", [{"backend": "mars", "incremental_index": True}])
493+
@pytest.mark.parametrize("config", [{"backend": "mars"}])
494494
def test_sync_execute(config):
495495
session = new_session(
496496
backend=config["backend"], n_cpu=2, web=False, use_uvloop=False
@@ -518,25 +518,31 @@ def test_sync_execute(config):
518518
assert d is c
519519
assert abs(session.fetch(d) - raw.sum()) < 0.001
520520

521-
# TODO(fyrestone): Remove this when the Ray backend support incremental index.
522-
if config["incremental_index"]:
523-
with tempfile.TemporaryDirectory() as tempdir:
524-
file_path = os.path.join(tempdir, "test.csv")
525-
pdf = pd.DataFrame(
526-
np.random.RandomState(0).rand(100, 10),
527-
columns=[f"col{i}" for i in range(10)],
528-
)
529-
pdf.to_csv(file_path, index=False)
530-
531-
df = md.read_csv(file_path, chunk_bytes=os.stat(file_path).st_size / 5)
532-
result = df.sum(axis=1).execute().fetch()
533-
expected = pd.read_csv(file_path).sum(axis=1)
534-
pd.testing.assert_series_equal(result, expected)
521+
with tempfile.TemporaryDirectory() as tempdir:
522+
file_path = os.path.join(tempdir, "test.csv")
523+
pdf = pd.DataFrame(
524+
np.random.RandomState(0).rand(100, 10),
525+
columns=[f"col{i}" for i in range(10)],
526+
)
527+
pdf.to_csv(file_path, index=False)
535528

536-
df = md.read_csv(file_path, chunk_bytes=os.stat(file_path).st_size / 5)
537-
result = df.head(10).execute().fetch()
538-
expected = pd.read_csv(file_path).head(10)
539-
pd.testing.assert_frame_equal(result, expected)
529+
df = md.read_csv(
530+
file_path,
531+
chunk_bytes=os.stat(file_path).st_size / 5,
532+
incremental_index=True,
533+
)
534+
result = df.sum(axis=1).execute().fetch()
535+
expected = pd.read_csv(file_path).sum(axis=1)
536+
pd.testing.assert_series_equal(result, expected)
537+
538+
df = md.read_csv(
539+
file_path,
540+
chunk_bytes=os.stat(file_path).st_size / 5,
541+
incremental_index=True,
542+
)
543+
result = df.head(10).execute().fetch()
544+
expected = pd.read_csv(file_path).head(10)
545+
pd.testing.assert_frame_equal(result, expected)
540546

541547
for worker_pool in session._session.client._cluster._worker_pools:
542548
_assert_storage_cleaned(

mars/deploy/oscar/tests/test_ray_dag.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,7 @@ async def test_iterative_tiling(ray_start_regular_shared2, create_cluster):
112112
await test_local.test_iterative_tiling(create_cluster)
113113

114114

115-
# TODO(fyrestone): Support incremental index in ray backend.
116115
@require_ray
117-
@pytest.mark.parametrize("config", [{"backend": "ray", "incremental_index": False}])
116+
@pytest.mark.parametrize("config", [{"backend": "ray"}])
118117
def test_sync_execute(config):
119118
test_local.test_sync_execute(config)

mars/oscar/backends/pool.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -457,7 +457,9 @@ async def join(self, timeout: float = None):
457457
async def stop(self):
458458
try:
459459
# clean global router
460-
Router.get_instance().remove_router(self._router)
460+
router = Router.get_instance()
461+
if router is not None:
462+
router.remove_router(self._router)
461463
stop_tasks = []
462464
# stop all servers
463465
stop_tasks.extend([server.stop() for server in self._servers])

mars/services/task/execution/mars/executor.py

+12
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from ..... import oscar as mo
2222
from .....core import ChunkGraph, TileContext
23+
from .....core.context import set_context
2324
from .....core.operand import (
2425
Fetch,
2526
MapReduceOperand,
@@ -33,6 +34,7 @@
3334
from .....resource import Resource
3435
from .....typing import TileableType, BandType
3536
from .....utils import Timer
37+
from ....context import ThreadedServiceContext
3638
from ....cluster.api import ClusterAPI
3739
from ....lifecycle.api import LifecycleAPI
3840
from ....meta.api import MetaAPI
@@ -121,6 +123,7 @@ async def create(
121123
task_id=task.task_id,
122124
cluster_api=cluster_api,
123125
)
126+
await cls._init_context(session_id, address)
124127
return cls(
125128
config,
126129
task,
@@ -142,6 +145,15 @@ async def _get_apis(cls, session_id: str, address: str):
142145
MetaAPI.create(session_id, address),
143146
)
144147

148+
@classmethod
149+
async def _init_context(cls, session_id: str, address: str):
150+
loop = asyncio.get_running_loop()
151+
context = ThreadedServiceContext(
152+
session_id, address, address, address, loop=loop
153+
)
154+
await context.init()
155+
set_context(context)
156+
145157
async def __aenter__(self):
146158
profiling = ProfilingData[self._task.task_id, "general"]
147159
# incref fetch tileables to ensure fetch data not deleted

mars/services/task/execution/ray/context.py

+101-2
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,108 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import functools
16+
import inspect
17+
from typing import Union
18+
19+
from .....core.context import Context
20+
from .....utils import implements, lazy_import
21+
from ....context import ThreadedServiceContext
22+
23+
ray = lazy_import("ray")
24+
25+
26+
class RayRemoteObjectManager:
27+
"""The remote object manager in task state actor."""
28+
29+
def __init__(self):
30+
self._named_remote_objects = {}
31+
32+
def create_remote_object(self, name: str, object_cls, *args, **kwargs):
33+
remote_object = object_cls(*args, **kwargs)
34+
self._named_remote_objects[name] = remote_object
35+
36+
def destroy_remote_object(self, name: str):
37+
self._named_remote_objects.pop(name, None)
38+
39+
async def call_remote_object(self, name: str, attr: str, *args, **kwargs):
40+
remote_object = self._named_remote_objects[name]
41+
meth = getattr(remote_object, attr)
42+
async_meth = self._sync_to_async(meth)
43+
return await async_meth(*args, **kwargs)
44+
45+
@staticmethod
46+
@functools.lru_cache(100)
47+
def _sync_to_async(func):
48+
if inspect.iscoroutinefunction(func):
49+
return func
50+
else:
51+
52+
async def async_wrapper(*args, **kwargs):
53+
return func(*args, **kwargs)
54+
55+
return async_wrapper
56+
57+
58+
class _RayRemoteObjectWrapper:
59+
def __init__(self, task_state_actor: "ray.actor.ActorHandle", name: str):
60+
self._task_state_actor = task_state_actor
61+
self._name = name
62+
63+
def __getattr__(self, attr):
64+
def wrap(*args, **kwargs):
65+
r = self._task_state_actor.call_remote_object.remote(
66+
self._name, attr, *args, **kwargs
67+
)
68+
return ray.get(r)
69+
70+
return wrap
71+
72+
73+
class _RayRemoteObjectContext:
74+
def __init__(
75+
self, actor_name_or_handle: Union[str, "ray.actor.ActorHandle"], *args, **kwargs
76+
):
77+
super().__init__(*args, **kwargs)
78+
self._actor_name_or_handle = actor_name_or_handle
79+
self._task_state_actor = None
80+
81+
def _get_task_state_actor(self) -> "ray.actor.ActorHandle":
82+
if self._task_state_actor is None:
83+
if isinstance(self._actor_name_or_handle, ray.actor.ActorHandle):
84+
self._task_state_actor = self._actor_name_or_handle
85+
else:
86+
self._task_state_actor = ray.get_actor(self._actor_name_or_handle)
87+
return self._task_state_actor
88+
89+
@implements(Context.create_remote_object)
90+
def create_remote_object(self, name: str, object_cls, *args, **kwargs):
91+
task_state_actor = self._get_task_state_actor()
92+
task_state_actor.create_remote_object.remote(name, object_cls, *args, **kwargs)
93+
return _RayRemoteObjectWrapper(task_state_actor, name)
94+
95+
@implements(Context.get_remote_object)
96+
def get_remote_object(self, name: str):
97+
task_state_actor = self._get_task_state_actor()
98+
return _RayRemoteObjectWrapper(task_state_actor, name)
99+
100+
@implements(Context.destroy_remote_object)
101+
def destroy_remote_object(self, name: str):
102+
task_state_actor = self._get_task_state_actor()
103+
task_state_actor.destroy_remote_object.remote(name)
104+
105+
106+
# TODO(fyrestone): Implement more APIs for Ray.
107+
class RayExecutionContext(_RayRemoteObjectContext, ThreadedServiceContext):
108+
"""The context for tiling."""
109+
110+
pass
111+
112+
113+
# TODO(fyrestone): Implement more APIs for Ray.
114+
class RayExecutionWorkerContext(_RayRemoteObjectContext, dict):
115+
"""The context for executing operands."""
15116

16-
# TODO(fyrestone): Should implement the mars.core.context.Context.
17-
class RayExecutionContext(dict):
18117
@staticmethod
19118
def new_custom_log_dir():
20119
return None

0 commit comments

Comments
 (0)