Skip to content

Commit 0e3b344

Browse files
authored
Fix parsing of input arguments for Client.map (#9071)
1 parent 01ea1eb commit 0e3b344

File tree

2 files changed

+35
-7
lines changed

2 files changed

+35
-7
lines changed

distributed/client.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -892,7 +892,7 @@ def _layer(self) -> dict[Key, GraphNode]:
892892

893893
if not self.kwargs:
894894
dsk = {
895-
key: Task(key, self.func, *args)
895+
key: Task(key, self.func, *parse_input(args)) # type: ignore[misc]
896896
for key, args in zip(self.keys, zip(*self.iterables))
897897
}
898898

@@ -907,12 +907,17 @@ def _layer(self) -> dict[Key, GraphNode]:
907907
else:
908908
kwargs2[k] = parse_input(v)
909909

910-
dsk.update(
911-
{
912-
key: Task(key, self.func, *args, **kwargs2)
913-
for key, args in zip(self.keys, zip(*self.iterables))
914-
}
915-
)
910+
dsk.update(
911+
{
912+
key: Task(
913+
key,
914+
self.func,
915+
*parse_input(args), # type: ignore[misc]
916+
**kwargs2,
917+
)
918+
for key, args in zip(self.keys, zip(*self.iterables))
919+
}
920+
)
916921
return dsk
917922

918923

distributed/tests/test_client.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8397,3 +8397,26 @@ async def test_count_serialization(c, s, a, use_lambda):
83978397
expected_global = expected_local + 1
83988398
assert x.local_count <= expected_local
83998399
assert x.global_count <= expected_global
8400+
8401+
8402+
@pytest.mark.parametrize("use_kwarg", [False, "simple", "future"])
8403+
@gen_cluster(client=True)
8404+
async def test_map_accepts_nested_futures(c, s, a, b, use_kwarg):
8405+
def reducer(futs, *, offset=0, **kwargs):
8406+
return sum(futs) + offset
8407+
8408+
f1 = c.submit(lambda: 10)
8409+
f2 = c.submit(lambda: 20)
8410+
offset = None
8411+
if use_kwarg == "simple":
8412+
future = c.map(reducer, [[f1, f2]], foo=True)[0]
8413+
elif use_kwarg == "future":
8414+
offset = c.submit(lambda: 1)
8415+
8416+
future = c.map(reducer, [[f1, f2]], offset=offset)[0]
8417+
8418+
else:
8419+
future = c.map(reducer, [[f1, f2]])[0]
8420+
8421+
result = await future.result()
8422+
assert result == 30 if not offset else 31

0 commit comments

Comments
 (0)