Skip to content

Commit 905604d

Browse files
authored
Ensure client submit does not serialize unnecessarily (#9057)
1 parent 01ab4e9 commit 905604d

File tree

2 files changed

+59
-1
lines changed

2 files changed

+59
-1
lines changed

distributed/client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2146,7 +2146,10 @@ def submit(
21462146
*(parse_input(a) for a in args),
21472147
**{k: parse_input(v) for k, v in kwargs.items()},
21482148
)
2149-
}
2149+
},
2150+
# We'd like to avoid hashing/tokenizing all of the above.
2151+
# The LLGExpr in this situation is as unique as it'll get.
2152+
_determ_token=uuid.uuid4().hex,
21502153
)
21512154
futures = self._graph_to_futures(
21522155
expr,

distributed/tests/test_client.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8295,3 +8295,58 @@ async def test_adjust_heartbeat(c, s):
82958295
await asyncio.sleep(0.1)
82968296
assert heartbeat_pc.callback_time == initial_value
82978297
assert heartbeat_pc.callback_time == scheduler_info_pc.callback_time
8298+
8299+
8300+
class CountSerialized:
    """Wrap a value and record how often instances get pickled.

    ``global_count`` tallies serializations across every instance of the
    class, while each instance additionally keeps its own ``local_count``.
    """

    # Class-wide tally of __reduce__ invocations.
    global_count = 0

    def __init__(self, x):
        self.x = x
        # Per-instance tally; reconstructed copies start back at zero.
        self.local_count = 0

    def __reduce__(self):
        # Bump both counters every time pickle asks for a reduction.
        CountSerialized.global_count += 1
        self.local_count += 1
        return (CountSerialized, (self.x,))

    def __dask_tokenize__(self):
        # Without a registered tokenizer, dask would otherwise fall back to
        # pickling repeatedly to derive a deterministic token.
        return (self.__class__, self.x)
8319+
8320+
8321+
def _task(foo, bar):
8322+
"""Some dummy task on CountSerialized objects."""
8323+
return foo.x == bar.x == 1
8324+
8325+
8326+
def _func(task_and_args):
8327+
f, args = task_and_args
8328+
return f(*args)
8329+
8330+
8331+
@pytest.mark.parametrize("use_lambda", [True, False])
@gen_cluster(client=True, nthreads=[("", 1)])
async def test_count_serialization(c, s, a, use_lambda):
    # Regression test: Client.submit should not pickle/tokenize task
    # arguments more often than necessary (commit: "Ensure client submit
    # does not serialize unnecessarily").
    CountSerialized.global_count = 0
    x = CountSerialized(1)
    if use_lambda:
        task_and_args = _task, [x, x]
        await c.submit(lambda: _func(task_and_args))
        # The lambda will trigger an exception during serialization that will
        # escalate and will attempt to serialize this repeatedly until falling
        # back to cloudpickle.
        # It also requires us to tokenize the lambda which again uses pickle.
        # The same happens if _task or _func is defined in local scope.
        expected_local = 3
    else:
        await c.submit(_task, x, x)
        expected_local = 1

    # once to the worker
    expected_global = expected_local + 1
    # Upper bounds, not exact counts: fewer serializations is always fine.
    assert x.local_count <= expected_local
    assert x.global_count <= expected_global

0 commit comments

Comments
 (0)