Skip to content

Commit e939be3

Browse files
alxmrs, pabloem, cisaacstern
authored
Windowing Support for the Dask Runner (#32941)
Windowing Support for the Dask Runner --------- Co-authored-by: Pablo E <[email protected]> Co-authored-by: Pablo <[email protected]> Co-authored-by: Charles Stern <[email protected]>
1 parent bff2cbb commit e939be3

File tree

9 files changed

+558
-42
lines changed

9 files changed

+558
-42
lines changed

.github/workflows/dask_runner_tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ jobs:
7878
run: pip install tox
7979
- name: Install SDK with dask
8080
working-directory: ./sdks/python
81-
run: pip install setuptools --upgrade && pip install -e .[gcp,dask,test]
81+
run: pip install setuptools --upgrade && pip install -e .[dask,test,dataframes]
8282
- name: Run tests basic unix
8383
if: startsWith(matrix.os, 'ubuntu') || startsWith(matrix.os, 'macos')
8484
working-directory: ./sdks/python

sdks/python/apache_beam/runners/dask/dask_runner.py

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,22 @@
3131
from apache_beam.pipeline import PipelineVisitor
3232
from apache_beam.runners.dask.overrides import dask_overrides
3333
from apache_beam.runners.dask.transform_evaluator import TRANSLATIONS
34+
from apache_beam.runners.dask.transform_evaluator import DaskBagWindowedIterator
35+
from apache_beam.runners.dask.transform_evaluator import Flatten
3436
from apache_beam.runners.dask.transform_evaluator import NoOp
3537
from apache_beam.runners.direct.direct_runner import BundleBasedDirectRunner
3638
from apache_beam.runners.runner import PipelineResult
3739
from apache_beam.runners.runner import PipelineState
40+
from apache_beam.transforms.sideinputs import SideInputMap
3841
from apache_beam.utils.interactive_utils import is_in_notebook
3942

43+
try:
44+
# Added to try to prevent threading related issues, see
45+
# https://github.com/pytest-dev/pytest/issues/3216#issuecomment-1502451456
46+
import dask.distributed as ddist
47+
except ImportError:
48+
ddist = {}
49+
4050

4151
class DaskOptions(PipelineOptions):
4252
@staticmethod
@@ -86,10 +96,9 @@ def _add_argparse_args(cls, parser: argparse.ArgumentParser) -> None:
8696

8797
@dataclasses.dataclass
8898
class DaskRunnerResult(PipelineResult):
89-
from dask import distributed
9099

91-
client: distributed.Client
92-
futures: t.Sequence[distributed.Future]
100+
client: ddist.Client
101+
futures: t.Sequence[ddist.Future]
93102

94103
def __post_init__(self):
95104
super().__init__(PipelineState.RUNNING)
@@ -99,8 +108,16 @@ def wait_until_finish(self, duration=None) -> str:
99108
if duration is not None:
100109
# Convert milliseconds to seconds
101110
duration /= 1000
102-
self.client.wait_for_workers(timeout=duration)
103-
self.client.gather(self.futures, errors='raise')
111+
for _ in ddist.as_completed(self.futures,
112+
timeout=duration,
113+
with_results=True):
114+
# without gathering results, worker errors are not raised on the client:
115+
# https://distributed.dask.org/en/stable/resilience.html#user-code-failures
116+
# so we want to gather results to raise errors client-side, but we do
117+
# not actually need to use the results here, so we just pass. to gather,
118+
# we use the iterative `as_completed(..., with_results=True)`, instead
119+
# of aggregate `client.gather`, to minimize memory footprint of results.
120+
pass
104121
self._state = PipelineState.DONE
105122
except: # pylint: disable=broad-except
106123
self._state = PipelineState.FAILED
@@ -133,6 +150,7 @@ def visit_transform(self, transform_node: AppliedPTransform) -> None:
133150
op_class = TRANSLATIONS.get(transform_node.transform.__class__, NoOp)
134151
op = op_class(transform_node)
135152

153+
op_kws = {"input_bag": None, "side_inputs": None}
136154
inputs = list(transform_node.inputs)
137155
if inputs:
138156
bag_inputs = []
@@ -144,13 +162,28 @@ def visit_transform(self, transform_node: AppliedPTransform) -> None:
144162
if prev_op in self.bags:
145163
bag_inputs.append(self.bags[prev_op])
146164

147-
if len(bag_inputs) == 1:
148-
self.bags[transform_node] = op.apply(bag_inputs[0])
165+
# Input to `Flatten` could be of length 1, e.g. a single-element
166+
# tuple: `(pcoll, ) | beam.Flatten()`. If so, we still pass it as
167+
# an iterable, because `Flatten.apply` always takes an iterable.
168+
if len(bag_inputs) == 1 and not isinstance(op, Flatten):
169+
op_kws["input_bag"] = bag_inputs[0]
149170
else:
150-
self.bags[transform_node] = op.apply(bag_inputs)
171+
op_kws["input_bag"] = bag_inputs
172+
173+
side_inputs = list(transform_node.side_inputs)
174+
if side_inputs:
175+
bag_side_inputs = []
176+
for si in side_inputs:
177+
si_asbag = self.bags.get(si.pvalue.producer)
178+
bag_side_inputs.append(
179+
SideInputMap(
180+
type(si),
181+
si._view_options(),
182+
DaskBagWindowedIterator(si_asbag, si._window_mapping_fn)))
183+
184+
op_kws["side_inputs"] = bag_side_inputs
151185

152-
else:
153-
self.bags[transform_node] = op.apply(None)
186+
self.bags[transform_node] = op.apply(**op_kws)
154187

155188
return DaskBagVisitor()
156189

@@ -159,6 +192,8 @@ def is_fnapi_compatible():
159192
return False
160193

161194
def run_pipeline(self, pipeline, options):
195+
import dask
196+
162197
# TODO(alxr): Create interactive notebook support.
163198
if is_in_notebook():
164199
raise NotImplementedError('interactive support will come later!')
@@ -177,6 +212,6 @@ def run_pipeline(self, pipeline, options):
177212

178213
dask_visitor = self.to_dask_bag_visitor()
179214
pipeline.visit(dask_visitor)
180-
181-
futures = client.compute(list(dask_visitor.bags.values()))
215+
opt_graph = dask.optimize(*list(dask_visitor.bags.values()))
216+
futures = client.compute(opt_graph)
182217
return DaskRunnerResult(client, futures)

0 commit comments

Comments (0)