Commit 19397fc

handle empty compute and determ token for map
1 parent e353fb2 commit 19397fc

File tree

4 files changed: +55, -31 lines changed

distributed/client.py

Lines changed: 28 additions & 21 deletions
@@ -848,6 +848,12 @@ class _MapExpr(Expr):
     ]
     _defaults = {"_cached_keys": None}
 
+    @property
+    def deterministic_token(self):
+        if not self.pure:
+            self._determ_token = uuid.uuid4().hex
+        return super().deterministic_token
+
     @property
     def keys(self) -> Iterable[Key]:
         if self._cached_keys is not None:
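
The new property regenerates the deterministic token whenever the map is impure, so repeated `Client.map(..., pure=False)` calls produce fresh task keys instead of deduplicating against each other. A minimal sketch of the intended behavior (assumes a running cluster; `inc` is a stand-in function, not part of this commit):

    from dask.distributed import Client

    client = Client()

    def inc(x):
        return x + 1

    # With pure=False, identical map() calls must not collapse onto the
    # same keys: each call should generate a fresh token.
    futs_a = client.map(inc, range(3), pure=False)
    futs_b = client.map(inc, range(3), pure=False)
    assert {f.key for f in futs_a}.isdisjoint(f.key for f in futs_b)
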
@@ -3629,27 +3635,28 @@ def compute(
         metadata = SpanMetadata(
             collections=[get_collections_metadata(v) for v in variables]
         )
-
-        expr = collections_to_expr(variables, optimize_graph, **kwargs)
-        from dask._expr import FinalizeCompute
-
-        expr = FinalizeCompute(expr)
-
-        expr = expr.optimize()
-        names = list(flatten(expr.__dask_keys__()))
-
-        futures_dict = self._graph_to_futures(
-            expr,
-            names,
-            workers=workers,
-            allow_other_workers=allow_other_workers,
-            resources=resources,
-            retries=retries,
-            user_priority=priority,
-            fifo_timeout=fifo_timeout,
-            actors=actors,
-            span_metadata=metadata,
-        )
+        futures_dict = {}
+        if variables:
+            expr = collections_to_expr(variables, optimize_graph, **kwargs)
+            from dask._expr import FinalizeCompute
+
+            expr = FinalizeCompute(expr)
+
+            expr = expr.optimize()
+            names = list(flatten(expr.__dask_keys__()))
+
+            futures_dict = self._graph_to_futures(
+                expr,
+                names,
+                workers=workers,
+                allow_other_workers=allow_other_workers,
+                resources=resources,
+                retries=retries,
+                user_priority=priority,
+                fifo_timeout=fifo_timeout,
+                actors=actors,
+                span_metadata=metadata,
+            )
 
         i = 0
         futures = []
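
With the `if variables:` guard, `Client.compute` no longer builds and submits an expression graph when nothing in the input is a dask collection; plain values simply pass through. A hedged usage sketch mirroring the new test below (assumes a running cluster):

    from dask import delayed
    from dask.distributed import Client

    client = Client()

    def inc(x):
        return x + 1

    # A bare value comes back unchanged; no graph is submitted for it.
    assert client.compute(1) == 1

    # Mixed input: the plain 1 passes through while delayed(inc)(1) is
    # computed; gather resolves the pair to concrete values.
    assert client.gather(client.compute((1, delayed(inc)(1)))) == [1, 2]
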

distributed/tests/test_client.py

Lines changed: 7 additions & 0 deletions
@@ -270,6 +270,13 @@ async def test_custom_key_with_batches(c, s, a, b):
     await wait(futs)
 
 
+@gen_cluster(client=True)
+async def test_compute_no_collection_or_future(c, s, *workers):
+    assert c.compute(1) == 1
+
+    assert await c.gather(c.compute((1, delayed(inc)(1)))) == [1, 2]
+
+
 @gen_cluster(client=True)
 async def test_compute_retries(c, s, a, b):
     args = [ZeroDivisionError("one"), ZeroDivisionError("two"), 3]

distributed/tests/test_scheduler.py

Lines changed: 10 additions & 5 deletions
@@ -2823,11 +2823,16 @@ async def test_default_task_duration_splits(c, s, a, b):
     npart = 10
     df = dd.from_pandas(pd.DataFrame({"A": range(100), "B": 1}), npartitions=npart)
     with dask.config.set({"dataframe.shuffle.method": "tasks"}):
-        graph = df.shuffle(
-            "A",
-            # If we don't have enough partitions, we'll fall back to a simple shuffle
-            max_branch=npart - 1,
-        ).sum()
+        graph = (
+            df.shuffle(
+                "A",
+                # If we don't have enough partitions, we'll fall back to a
+                # simple shuffle
+                max_branch=npart - 1,
+            )
+            # Block optimizer from killing the shuffle
+            .map_partitions(lambda x: len(x)).sum()
+        )
     fut = c.compute(graph)
     await wait(fut)
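
Context for this test change: a bare `df.shuffle("A").sum()` can be optimized so aggressively that the shuffle disappears (a global sum does not depend on row order), leaving no split tasks to measure. An opaque `map_partitions` call blocks that rewrite. A hedged illustration, assuming dask.dataframe's expression optimizer (the same pattern is applied in test_steal.py below):

    import pandas as pd
    import dask
    import dask.dataframe as dd

    df = dd.from_pandas(pd.DataFrame({"A": range(100), "B": 1}), npartitions=10)

    with dask.config.set({"dataframe.shuffle.method": "tasks"}):
        # Order-independent reduction: the optimizer may elide the shuffle.
        elided = df.shuffle("A").sum()

        # The lambda is opaque to the optimizer, so the shuffle survives.
        kept = df.shuffle("A").map_partitions(lambda x: len(x)).sum()
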

distributed/tests/test_steal.py

Lines changed: 10 additions & 5 deletions
@@ -1082,11 +1082,16 @@ async def test_blocklist_shuffle_split(c, s, a, b):
     npart = 10
     df = dd.from_pandas(pd.DataFrame({"A": range(100), "B": 1}), npartitions=npart)
     with dask.config.set({"dataframe.shuffle.method": "tasks"}):
-        graph = df.shuffle(
-            "A",
-            # If we don't have enough partitions, we'll fall back to a simple shuffle
-            max_branch=npart - 1,
-        ).sum()
+        graph = (
+            df.shuffle(
+                "A",
+                # If we don't have enough partitions, we'll fall back to a
+                # simple shuffle
+                max_branch=npart - 1,
+            )
+            # Block optimizer from killing the shuffle
+            .map_partitions(lambda x: len(x)).sum()
+        )
     res = c.compute(graph)
 
     while not s.tasks:
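
Same rationale as the test_scheduler.py change above: the opaque `map_partitions` keeps the shuffle in the optimized graph, so the blocklisted split tasks actually reach the scheduler. A hypothetical way a test could assert that (not part of this commit; assumes `key_split` is importable from dask.base and `s` is the scheduler fixture):

    from dask.base import key_split

    # After submitting the graph, shuffle-related task prefixes should
    # be present among the scheduler's tasks.
    prefixes = {key_split(key) for key in s.tasks}
    assert any("shuffle" in prefix for prefix in prefixes)
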
