Skip to content

Commit 204dd90

Browse files
author
Xuye (Chris) Qin
authored
[oscar] Add cancel support, optimize error handling, add kill_actor API (#2027)
1 parent ab5283d commit 204dd90

19 files changed

+802
-135
lines changed

mars/lib/gipc.pyx

+27-18
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,30 @@ except ImportError:
442442
ForkProcess = multiprocessing.Process
443443

444444

445+
if WINDOWS:
    # No Popen subclass is defined on Windows.
    _GPopen = None
else:
    try:
        from multiprocessing.forking import Popen as mp_Popen
    except ImportError:  # pragma: no cover
        # multiprocessing's internal layout changed between 3.3 and 3.4.
        from multiprocessing.popen_fork import Popen as mp_Popen

    class _GPopen(mp_Popen):
        """``Popen`` variant whose ``poll()`` never reaps the child.

        ``multiprocessing.process.Process.start()`` and other methods may
        call ``multiprocessing.process._cleanup()``. That, and other mp
        methods, may call multiprocessing's ``Popen.poll()``, which in turn
        invokes ``os.waitpid()``. In extreme cases (high-frequency child
        process creation, short-lived child processes) this competes with
        libev's SIGCHLD handler and may win, so libev cannot retrieve all
        SIGCHLD signals corresponding to started children. This could make
        certain ``_GProcess.join()`` calls block forever.

        Therefore ``poll()`` is overridden so that it never calls
        ``os.waitpid()``; libev does the job instead.
        """

        def poll(self, *args, **kwargs):
            # Accept and ignore any arguments; do nothing.
            return None
467+
468+
445469
class _GProcess(ForkProcess):
446470
"""
447471
Compatible with the ``multiprocessing.Process`` API.
@@ -488,24 +512,9 @@ class _GProcess(ForkProcess):
488512
# On Windows, cooperative `join()` is realized via polling (non-blocking
489513
# calls to `Process.is_alive()`) and the original `join()` method.
490514
if not WINDOWS:
491-
# multiprocessing.process.Process.start() and other methods may
492-
# call multiprocessing.process._cleanup(). This and other mp methods
493-
# may call multiprocessing's Popen.poll() which itself invokes
494-
# os.waitpid(). In extreme cases (high-frequent child process
495-
# creation, short-living child processes), this competes with libev's
496-
# SIGCHLD handler and may win, resulting in libev not being able to
497-
# retrieve all SIGCHLD signals corresponding to started children. This
498-
# could make certain _GProcess.join() calls block forever.
499-
# -> Prevent multiprocessing's Popen.poll() from calling
500-
# os.waitpid(). Let libev do the job.
501-
try:
502-
from multiprocessing.forking import Popen as mp_Popen
503-
except ImportError: # pragma: no cover
504-
# multiprocessing's internal structure has changed from 3.3 to 3.4.
505-
from multiprocessing.popen_fork import Popen as mp_Popen
506-
# Monkey-patch and forget about the name.
507-
mp_Popen.poll = lambda *a, **b: None
508-
del mp_Popen
515+
@staticmethod
516+
def _Popen(process_obj):
517+
return _GPopen(process_obj)
509518

510519
def start(self):
511520
# Start grabbing SIGCHLD within libev event loop.

mars/oscar/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
del aio
1818

1919
from .api import actor_ref, create_actor, has_actor, destroy_actor, \
20-
Actor, create_actor_pool
20+
kill_actor, Actor, create_actor_pool
2121
from .errors import ActorNotExist, ActorAlreadyExist
2222
from .utils import create_actor_ref
2323

mars/oscar/api.py

+5
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@ async def actor_ref(*args, **kwargs):
4040
return await ctx.actor_ref(*args, **kwargs)
4141

4242

43+
async def kill_actor(actor_ref):
    """Kill the actor referenced by ``actor_ref`` via the current context.

    Awaits ``kill_actor`` on the object returned by ``get_context()`` and
    returns whatever that call returns.
    """
    return await get_context().kill_actor(actor_ref)
46+
47+
4348
async def create_actor_pool(address: str,
4449
n_process: int = None,
4550
**kwargs):

mars/oscar/backends/mars/allocate_strategy.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ def get_allocated_address(self,
7373
config: ActorPoolConfig,
7474
allocated: allocated_type) -> str:
7575
# allocate to main process
76-
return config.get_external_address(0)
76+
main_process_index = config.get_process_indexes()[0]
77+
return config.get_external_address(main_process_index)
7778

7879

7980
class RandomSubPool(AllocateStrategy):
@@ -86,6 +87,20 @@ def get_allocated_address(self,
8687
return choice(config.get_external_addresses()[1:])
8788

8889

90+
class ProcessIndex(AllocateStrategy):
    """Allocate to the pool found at a given position in the config.

    ``process_index`` is used as a position into the list returned by
    ``config.get_process_indexes()``; the entry found there is the index
    that is actually looked up in the pool config.
    """
    __slots__ = ('process_index',)

    def __init__(self, process_index: int):
        self.process_index = process_index

    @implements(AllocateStrategy.get_allocated_address)
    def get_allocated_address(self,
                              config: ActorPoolConfig,
                              allocated: allocated_type) -> str:
        # Map the stored position onto the process index known to the config.
        known_indexes = config.get_process_indexes()
        pool_config = config.get_pool_config(known_indexes[self.process_index])
        # Return the first entry of that pool's external addresses.
        external_addresses = pool_config['external_address']
        return external_addresses[0]
102+
103+
89104
class RandomLabel(AllocateStrategy):
90105
__slots__ = 'label',
91106

mars/oscar/backends/mars/communication/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
# limitations under the License.
1414

1515
from .base import Client, Server, Channel
16-
from .core import get_client_type, get_server_type, gen_internal_address
16+
from .core import get_client_type, get_server_type, \
17+
gen_internal_address, gen_local_address
1718
from .dummy import DummyClient, DummyServer, DummyChannel
1819
from .socket import SocketClient, SocketServer, UnixSocketClient, \
1920
UnixSocketServer, SocketChannel

mars/oscar/backends/mars/communication/core.py

+4
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,7 @@ def get_server_type(address: str) -> Type[Server]:
6363

6464
def gen_internal_address(process_index: int) -> str:
6565
return f'unixsocket:///{process_index}'
66+
67+
68+
def gen_local_address(process_index: int) -> str:
    """Build the ``dummy://`` address string for ``process_index``."""
    local_address = f'dummy://{process_index}'
    return local_address

mars/oscar/backends/mars/communication/dummy.py

+18-19
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@
1515
import asyncio
1616
import concurrent.futures as futures
1717
from typing import Any, Callable, Coroutine, Dict, Type
18+
from urllib.parse import urlparse
1819

1920
from .....utils import implements, classproperty
2021
from .base import Channel, ChannelType, Server, Client
2122
from .core import register_client, register_server
2223
from .errors import ChannelClosed
2324

24-
DUMMY_ADDRESS = 'dummy://'
25+
DEFAULT_DUMMY_ADDRESS = 'dummy://0'
2526

2627

2728
class DummyChannel(Channel):
@@ -81,7 +82,7 @@ def closed(self) -> bool:
8182
class DummyServer(Server):
8283
__slots__ = '_closed',
8384

84-
_instance = None
85+
_address_to_instances: Dict[str, "DummyServer"] = dict()
8586
scheme = 'dummy'
8687

8788
def __init__(self,
@@ -91,8 +92,8 @@ def __init__(self,
9192
self._closed = asyncio.Event()
9293

9394
@classmethod
94-
def get_instance(cls):
95-
return cls._instance
95+
def get_instance(cls, address: str):
96+
return cls._address_to_instances[address]
9697

9798
@classproperty
9899
@implements(Server.client_type)
@@ -108,23 +109,22 @@ def channel_type(self) -> ChannelType:
108109
@implements(Server.create)
109110
async def create(config: Dict) -> "DummyServer":
110111
config = config.copy()
111-
address = config.pop('address', DUMMY_ADDRESS)
112+
address = config.pop('address', DEFAULT_DUMMY_ADDRESS)
112113
handle_channel = config.pop('handle_channel')
113-
if address != DUMMY_ADDRESS: # pragma: no cover
114+
if urlparse(address).scheme != DummyServer.scheme: # pragma: no cover
114115
raise ValueError(f'Address for DummyServer '
115-
f'should be {DUMMY_ADDRESS}, '
116+
f'should be starts with "dummy://", '
116117
f'got {address}')
117118
if config: # pragma: no cover
118119
raise TypeError(f'Creating DummyServer got unexpected '
119120
f'arguments: {",".join(config)}')
120121

121-
# DummyServer is singleton
122-
if DummyServer._instance is not None:
123-
return DummyServer._instance
124-
125-
server = DummyServer(DUMMY_ADDRESS, handle_channel)
126-
DummyServer._instance = server
127-
return server
122+
try:
123+
return DummyServer.get_instance(address)
124+
except KeyError:
125+
server = DummyServer(address, handle_channel)
126+
DummyServer._address_to_instances[address] = server
127+
return server
128128

129129
@implements(Server.start)
130130
async def start(self):
@@ -151,7 +151,7 @@ async def on_connected(self, *args, **kwargs):
151151
@implements(Server.stop)
152152
async def stop(self):
153153
self._closed.set()
154-
DummyServer._instance = None
154+
del DummyServer._address_to_instances[self.address]
155155

156156
@property
157157
@implements(Server.stopped)
@@ -179,10 +179,10 @@ def __init__(self,
179179
async def connect(dest_address: str,
180180
local_address: str = None,
181181
**kwargs) -> "Client":
182-
if dest_address != DUMMY_ADDRESS: # pragma: no cover
183-
raise ValueError(f'Destination address has to be "dummy://" '
182+
if urlparse(dest_address).scheme != DummyServer.scheme: # pragma: no cover
183+
raise ValueError(f'Destination address should start with "dummy://" '
184184
f'for DummyClient, got {dest_address}')
185-
server = DummyServer.get_instance()
185+
server = DummyServer.get_instance(dest_address)
186186
if server is None: # pragma: no cover
187187
raise RuntimeError('DummyServer needs to be created '
188188
'first before DummyClient')
@@ -200,6 +200,5 @@ async def connect(dest_address: str,
200200
@implements(Client.close)
201201
async def close(self):
202202
await super().close()
203-
DummyClient._instance = None
204203
self._task.cancel()
205204
self._task = None

mars/oscar/backends/mars/communication/socket.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -210,9 +210,12 @@ async def connect(dest_address: str,
210210
return SocketClient(local_address, dest_address, channel)
211211

212212

213+
TEMPDIR = tempfile.gettempdir()
214+
215+
213216
@lru_cache(100)
214217
def _gen_unix_socket_default_path(process_index):
215-
return f'{tempfile.gettempdir()}/mars/' \
218+
return f'{TEMPDIR}/mars/' \
216219
f'{md5(to_binary(str(process_index))).hexdigest()}' # nosec
217220

218221

mars/oscar/backends/mars/communication/tests/test_comm.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import asyncio
1616
import sys
1717
import multiprocessing
18-
from typing import Union
18+
from typing import Union, List, Tuple, Type, Dict
1919

2020
import numpy as np
2121
import pytest
@@ -24,7 +24,7 @@
2424
from mars.oscar.backends.mars.communication import \
2525
SocketChannel, SocketServer, UnixSocketServer, \
2626
DummyChannel, DummyServer, get_client_type, \
27-
SocketClient, UnixSocketClient, DummyClient
27+
SocketClient, UnixSocketClient, DummyClient, Server
2828
from mars.utils import get_next_port
2929

3030

@@ -33,13 +33,13 @@
3333

3434

3535
# server_type, config, con
36-
params = [
36+
params: List[Tuple[Type[Server], Dict, str]] = [
3737
(SocketServer, dict(host='127.0.0.1', port=port), f'127.0.0.1:{port}'),
3838
]
3939
if sys.platform != 'win32':
4040
params.append((UnixSocketServer, dict(process_index='0'), f'unixsocket:///0'))
4141
local_params = params.copy()
42-
local_params.append((DummyServer, dict(), 'dummy://'))
42+
local_params.append((DummyServer, dict(), 'dummy://0'))
4343

4444

4545
@pytest.mark.parametrize(

mars/oscar/backends/mars/config.py

+7
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,13 @@ def get_external_address(self, process_index: int) -> str:
5959
def get_process_indexes(self):
6060
return list(self._conf['pools'])
6161

62+
def get_process_index(self, external_addrress: str) -> int:
    """Find the process index of the pool that owns an external address.

    Parameters
    ----------
    external_addrress : str
        External address to look up. The parameter name keeps its original
        spelling so that existing keyword callers continue to work.

    Returns
    -------
    int
        Key in ``self._conf['pools']`` of the first pool whose
        ``'external_address'`` collection contains ``external_addrress``.

    Raises
    ------
    ValueError
        If no pool lists the given address.
    """
    for process_index, conf in self._conf['pools'].items():
        if external_addrress in conf['external_address']:
            return process_index
    # Error message previously misspelled the identifier as "proces_index".
    raise ValueError(f'Cannot get process_index '
                     f'for {external_addrress}')  # pragma: no cover
68+
6269
def get_external_addresses(self, label=None) -> List[str]:
6370
result = []
6471
for c in self._conf['pools'].values():

0 commit comments

Comments
 (0)