3131from apache_beam .pipeline import PipelineVisitor
3232from apache_beam .runners .dask .overrides import dask_overrides
3333from apache_beam .runners .dask .transform_evaluator import TRANSLATIONS
34+ from apache_beam .runners .dask .transform_evaluator import DaskBagWindowedIterator
35+ from apache_beam .runners .dask .transform_evaluator import Flatten
3436from apache_beam .runners .dask .transform_evaluator import NoOp
3537from apache_beam .runners .direct .direct_runner import BundleBasedDirectRunner
3638from apache_beam .runners .runner import PipelineResult
3739from apache_beam .runners .runner import PipelineState
40+ from apache_beam .transforms .sideinputs import SideInputMap
3841from apache_beam .utils .interactive_utils import is_in_notebook
3942
43+ try :
44+ # Added to try to prevent threading related issues, see
45+ # https://github.com/pytest-dev/pytest/issues/3216#issuecomment-1502451456
46+ import dask .distributed as ddist
47+ except ImportError :
48+ ddist = {}
49+
4050
4151class DaskOptions (PipelineOptions ):
4252 @staticmethod
@@ -86,10 +96,9 @@ def _add_argparse_args(cls, parser: argparse.ArgumentParser) -> None:
8696
8797@dataclasses .dataclass
8898class DaskRunnerResult (PipelineResult ):
89- from dask import distributed
9099
91- client : distributed .Client
92- futures : t .Sequence [distributed .Future ]
100+ client : ddist .Client
101+ futures : t .Sequence [ddist .Future ]
93102
94103 def __post_init__ (self ):
95104 super ().__init__ (PipelineState .RUNNING )
@@ -99,8 +108,16 @@ def wait_until_finish(self, duration=None) -> str:
99108 if duration is not None :
100109 # Convert milliseconds to seconds
101110 duration /= 1000
102- self .client .wait_for_workers (timeout = duration )
103- self .client .gather (self .futures , errors = 'raise' )
111+ for _ in ddist .as_completed (self .futures ,
112+ timeout = duration ,
113+ with_results = True ):
114+ # without gathering results, worker errors are not raised on the client:
115+ # https://distributed.dask.org/en/stable/resilience.html#user-code-failures
116+ # so we want to gather results to raise errors client-side, but we do
117+ # not actually need to use the results here, so we just pass. to gather,
118+ # we use the iterative `as_completed(..., with_results=True)`, instead
119+ # of aggregate `client.gather`, to minimize memory footprint of results.
120+ pass
104121 self ._state = PipelineState .DONE
105122 except : # pylint: disable=broad-except
106123 self ._state = PipelineState .FAILED
@@ -133,6 +150,7 @@ def visit_transform(self, transform_node: AppliedPTransform) -> None:
133150 op_class = TRANSLATIONS .get (transform_node .transform .__class__ , NoOp )
134151 op = op_class (transform_node )
135152
153+ op_kws = {"input_bag" : None , "side_inputs" : None }
136154 inputs = list (transform_node .inputs )
137155 if inputs :
138156 bag_inputs = []
@@ -144,13 +162,28 @@ def visit_transform(self, transform_node: AppliedPTransform) -> None:
144162 if prev_op in self .bags :
145163 bag_inputs .append (self .bags [prev_op ])
146164
147- if len (bag_inputs ) == 1 :
148- self .bags [transform_node ] = op .apply (bag_inputs [0 ])
165+ # Input to `Flatten` could be of length 1, e.g. a single-element
166+ # tuple: `(pcoll, ) | beam.Flatten()`. If so, we still pass it as
167+ # an iterable, because `Flatten.apply` always takes an iterable.
168+ if len (bag_inputs ) == 1 and not isinstance (op , Flatten ):
169+ op_kws ["input_bag" ] = bag_inputs [0 ]
149170 else :
150- self .bags [transform_node ] = op .apply (bag_inputs )
171+ op_kws ["input_bag" ] = bag_inputs
172+
173+ side_inputs = list (transform_node .side_inputs )
174+ if side_inputs :
175+ bag_side_inputs = []
176+ for si in side_inputs :
177+ si_asbag = self .bags .get (si .pvalue .producer )
178+ bag_side_inputs .append (
179+ SideInputMap (
180+ type (si ),
181+ si ._view_options (),
182+ DaskBagWindowedIterator (si_asbag , si ._window_mapping_fn )))
183+
184+ op_kws ["side_inputs" ] = bag_side_inputs
151185
152- else :
153- self .bags [transform_node ] = op .apply (None )
186+ self .bags [transform_node ] = op .apply (** op_kws )
154187
155188 return DaskBagVisitor ()
156189
@@ -159,6 +192,8 @@ def is_fnapi_compatible():
159192 return False
160193
161194 def run_pipeline (self , pipeline , options ):
195+ import dask
196+
162197 # TODO(alxr): Create interactive notebook support.
163198 if is_in_notebook ():
164199 raise NotImplementedError ('interactive support will come later!' )
@@ -177,6 +212,6 @@ def run_pipeline(self, pipeline, options):
177212
178213 dask_visitor = self .to_dask_bag_visitor ()
179214 pipeline .visit (dask_visitor )
180-
181- futures = client .compute (list ( dask_visitor . bags . values ()) )
215+ opt_graph = dask . optimize ( * list ( dask_visitor . bags . values ()))
216+ futures = client .compute (opt_graph )
182217 return DaskRunnerResult (client , futures )
0 commit comments