Skip to content

Commit ea2b643

Browse files
authored
Better support ProcessPoolExecutors (#5063)
Previously, we sent non-serializable state, such as locks and the worker itself, into the function that was submitted to the executor. Now we only do this if we think that the executor is a ThreadPoolExecutor.
* Remove Worker.executor_submit. Fixes #3938. This method was old and was a workaround for a Tornado issue.
* Remove gen.coroutine from worker.py. We're generally trying to clean this up from the codebase.
* Add Worker.active_keys.
1 parent 4734833 commit ea2b643

File tree

2 files changed

+99
-73
lines changed

2 files changed

+99
-73
lines changed

distributed/tests/test_worker.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import sys
66
import threading
77
import traceback
8-
from concurrent.futures import ThreadPoolExecutor
8+
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
99
from numbers import Number
1010
from operator import add
1111
from time import sleep
@@ -508,17 +508,14 @@ async def f(dask_worker=None):
508508
@gen_cluster(client=True, nthreads=[])
509509
async def test_Executor(c, s):
510510
with ThreadPoolExecutor(2) as e:
511-
w = Worker(s.address, executor=e)
512-
assert w.executor is e
513-
w = await w
511+
async with Worker(s.address, executor=e) as w:
512+
assert w.executor is e
514513

515-
future = c.submit(inc, 1)
516-
result = await future
517-
assert result == 2
518-
519-
assert e._threads # had to do some work
514+
future = c.submit(inc, 1)
515+
result = await future
516+
assert result == 2
520517

521-
await w.close()
518+
assert e._threads # had to do some work
522519

523520

524521
@pytest.mark.skip(
@@ -2028,6 +2025,34 @@ def get_thread_name():
20282025
assert "Dask-GPU-Threads" in gpu_result
20292026

20302027

2028+
@gen_cluster(client=True)
2029+
async def test_process_executor(c, s, a, b):
2030+
with ProcessPoolExecutor() as e:
2031+
a.executors["processes"] = e
2032+
b.executors["processes"] = e
2033+
2034+
future = c.submit(os.getpid, pure=False)
2035+
assert (await future) == os.getpid()
2036+
2037+
with dask.annotate(executor="processes"):
2038+
future = c.submit(os.getpid, pure=False)
2039+
2040+
assert (await future) != os.getpid()
2041+
2042+
2043+
@gen_cluster(client=True)
2044+
async def test_process_executor_kills_process(c, s, a, b):
2045+
with ProcessPoolExecutor() as e:
2046+
a.executors["processes"] = e
2047+
b.executors["processes"] = e
2048+
2049+
with dask.annotate(executor="processes", retries=1):
2050+
future = c.submit(sys.exit, 1)
2051+
2052+
exc = await future.exception()
2053+
assert "SystemExit(1)" in repr(exc)
2054+
2055+
20312056
def assert_task_states_on_worker(expected, worker):
20322057
for dep_key, expected_state in expected.items():
20332058
assert dep_key in worker.tasks, (worker.name, dep_key, worker.tasks)

distributed/worker.py

Lines changed: 64 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from typing import Dict, Iterable, Optional
2020

2121
from tlz import first, keymap, merge, pluck # noqa: F401
22-
from tornado import gen
2322
from tornado.ioloop import IOLoop, PeriodicCallback
2423

2524
import dask
@@ -438,6 +437,7 @@ def __init__(
438437

439438
self.active_threads_lock = threading.Lock()
440439
self.active_threads = dict()
440+
self.active_keys = set()
441441
self.profile_keys = defaultdict(profile.create)
442442
self.profile_keys_history = deque(maxlen=3600)
443443
self.profile_recent = profile.create()
@@ -968,16 +968,14 @@ async def heartbeat(self):
968968
logger.debug("Heartbeat: %s", self.address)
969969
try:
970970
start = time()
971-
with self.active_threads_lock:
972-
active_keys = list(self.active_threads.values())
973971
response = await retry_operation(
974972
self.scheduler.heartbeat_worker,
975973
address=self.contact_address,
976974
now=start,
977975
metrics=await self.get_metrics(),
978976
executing={
979977
key: start - self.tasks[key].start_time
980-
for key in active_keys
978+
for key in self.active_keys
981979
if key in self.tasks
982980
},
983981
)
@@ -2686,41 +2684,6 @@ def release_key(
26862684
# Execute Task #
26872685
################
26882686

2689-
# FIXME: this breaks if changed to async def...
2690-
# xref: https://github.com/dask/distributed/issues/3938
2691-
@gen.coroutine
2692-
def executor_submit(self, key, function, args=(), kwargs=None, executor=None):
2693-
"""Safely run function in thread pool executor
2694-
2695-
We've run into issues running concurrent.future futures within
2696-
tornado. Apparently it's advantageous to use timeouts and periodic
2697-
callbacks to ensure things run smoothly. This can get tricky, so we
2698-
pull it off into an separate method.
2699-
"""
2700-
executor = executor or self.executors["default"]
2701-
job_counter[0] += 1
2702-
# logger.info("%s:%d Starts job %d, %s", self.ip, self.port, i, key)
2703-
kwargs = kwargs or {}
2704-
future = executor.submit(function, *args, **kwargs)
2705-
pc = PeriodicCallback(
2706-
lambda: logger.debug("future state: %s - %s", key, future._state), 1000
2707-
)
2708-
ts = self.tasks.get(key)
2709-
if ts is not None:
2710-
ts.start_time = time()
2711-
pc.start()
2712-
try:
2713-
yield future
2714-
finally:
2715-
pc.stop()
2716-
if ts is not None:
2717-
ts.stop_time = time()
2718-
2719-
result = future.result()
2720-
2721-
# logger.info("Finish job %d, %s", i, key)
2722-
raise gen.Return(result)
2723-
27242687
def run(self, comm, function, args=(), wait=True, kwargs=None):
27252688
return run(self, comm, function=function, args=args, kwargs=kwargs, wait=wait)
27262689

@@ -2782,19 +2745,16 @@ async def actor_execute(
27822745
if iscoroutinefunction(func):
27832746
result = await func(*args, **kwargs)
27842747
elif separate_thread:
2785-
result = await self.executor_submit(
2786-
name,
2748+
result = await self.loop.run_in_executor(
2749+
self.executors["actor"],
27872750
apply_function_actor,
2788-
args=(
2789-
func,
2790-
args,
2791-
kwargs,
2792-
self.execution_state,
2793-
name,
2794-
self.active_threads,
2795-
self.active_threads_lock,
2796-
),
2797-
executor=self.executors["actor"],
2751+
func,
2752+
args,
2753+
kwargs,
2754+
self.execution_state,
2755+
name,
2756+
self.active_threads,
2757+
self.active_threads_lock,
27982758
)
27992759
else:
28002760
result = func(*args, **kwargs)
@@ -2946,11 +2906,14 @@ async def execute(self, key, report=False):
29462906
executor,
29472907
) # TODO: comment out?
29482908
assert key == ts.key
2909+
self.active_keys.add(ts.key)
29492910
try:
2950-
result = await self.executor_submit(
2951-
ts.key,
2952-
apply_function,
2953-
args=(
2911+
e = self.executors[executor]
2912+
ts.start_time = time()
2913+
if "ThreadPoolExecutor" in str(type(e)):
2914+
result = await self.loop.run_in_executor(
2915+
e,
2916+
apply_function,
29542917
function,
29552918
args2,
29562919
kwargs2,
@@ -2959,12 +2922,32 @@ async def execute(self, key, report=False):
29592922
self.active_threads,
29602923
self.active_threads_lock,
29612924
self.scheduler_delay,
2962-
),
2963-
executor=self.executors[executor],
2964-
)
2925+
)
2926+
else:
2927+
try:
2928+
start = time() + self.scheduler_delay
2929+
result = await self.loop.run_in_executor(
2930+
e,
2931+
apply_function_simple,
2932+
function,
2933+
args2,
2934+
kwargs2,
2935+
self.scheduler_delay,
2936+
)
2937+
except BaseException as e:
2938+
msg = error_message(e)
2939+
msg["op"] = "task-erred"
2940+
msg["actual-exception"] = e
2941+
msg["start"] = start
2942+
msg["stop"] = time() + self.scheduler_delay
2943+
msg["thread"] = None
2944+
result = msg
2945+
29652946
except RuntimeError as e:
29662947
executor_error = e
29672948
raise
2949+
finally:
2950+
self.active_keys.discard(ts.key)
29682951

29692952
# We'll need to check again for the task state since it may have
29702953
# changed since the execution was kicked off. In particular, it may
@@ -3854,6 +3837,27 @@ def apply_function(
38543837
thread_state.start_time = time()
38553838
thread_state.execution_state = execution_state
38563839
thread_state.key = key
3840+
3841+
msg = apply_function_simple(function, args, kwargs, time_delay)
3842+
3843+
with active_threads_lock:
3844+
del active_threads[ident]
3845+
return msg
3846+
3847+
3848+
def apply_function_simple(
3849+
function,
3850+
args,
3851+
kwargs,
3852+
time_delay,
3853+
):
3854+
"""Run a function, collect information
3855+
3856+
Returns
3857+
-------
3858+
msg: dictionary with status, result/error, timings, etc..
3859+
"""
3860+
ident = threading.get_ident()
38573861
start = time()
38583862
try:
38593863
result = function(*args, **kwargs)
@@ -3874,8 +3878,6 @@ def apply_function(
38743878
msg["start"] = start + time_delay
38753879
msg["stop"] = end + time_delay
38763880
msg["thread"] = ident
3877-
with active_threads_lock:
3878-
del active_threads[ident]
38793881
return msg
38803882

38813883

@@ -4020,9 +4022,8 @@ async def run(server, comm, function, args=(), kwargs=None, is_coro=None, wait=T
40204022
pass
40214023
else:
40224024

4023-
@gen.coroutine
4024-
def gpu_metric(worker):
4025-
result = yield offload(nvml.real_time)
4025+
async def gpu_metric(worker):
4026+
result = await offload(nvml.real_time)
40264027
return result
40274028

40284029
DEFAULT_METRICS["gpu"] = gpu_metric

0 commit comments

Comments
 (0)