diff --git a/sdks/python/apache_beam/runners/direct/direct_runner.py b/sdks/python/apache_beam/runners/direct/direct_runner.py index 8b8937653688..9d9d75833c89 100644 --- a/sdks/python/apache_beam/runners/direct/direct_runner.py +++ b/sdks/python/apache_beam/runners/direct/direct_runner.py @@ -52,6 +52,7 @@ from apache_beam.transforms.ptransform import PTransform from apache_beam.transforms.timeutil import TimeDomain from apache_beam.typehints import trivial_inference +from apache_beam.utils.interactive_utils import is_in_ipython __all__ = ['BundleBasedDirectRunner', 'DirectRunner', 'SwitchingDirectRunner'] @@ -114,7 +115,11 @@ class _PrismRunnerSupportVisitor(PipelineVisitor): """Visitor determining if a Pipeline can be run on the PrismRunner.""" def accept(self, pipeline): self.supported_by_prism_runner = True - pipeline.visit(self) + # TODO(https://github.com/apache/beam/issues/33623): Prism currently does not support interactive mode + if is_in_ipython(): + self.supported_by_prism_runner = False + else: + pipeline.visit(self) return self.supported_by_prism_runner def visit_transform(self, applied_ptransform): @@ -136,40 +141,32 @@ def visit_transform(self, applied_ptransform): if userstate.is_stateful_dofn(dofn): # https://github.com/apache/beam/issues/32786 - # Remove once Real time clock is used. - _, timer_specs = userstate.get_dofn_specs(dofn) + state_specs, timer_specs = userstate.get_dofn_specs(dofn) for timer in timer_specs: if timer.time_domain == TimeDomain.REAL_TIME: self.supported_by_prism_runner = False - tryingPrism = False + for state in state_specs: + if isinstance(state, userstate.CombiningValueStateSpec): + self.supported_by_prism_runner = False + + # Use BundleBasedDirectRunner if other runners are missing needed features. + runner = BundleBasedDirectRunner() + # Check whether all transforms used in the pipeline are supported by the - # FnApiRunner, and the pipeline was not meant to be run as streaming. - if _FnApiRunnerSupportVisitor().accept(pipeline): - from apache_beam.portability.api import beam_provision_api_pb2 - from apache_beam.runners.portability.fn_api_runner import fn_runner - from apache_beam.runners.portability.portable_runner import JobServiceHandle - all_options = options.get_all_options() - encoded_options = JobServiceHandle.encode_pipeline_options(all_options) - provision_info = fn_runner.ExtendedProvisionInfo( - beam_provision_api_pb2.ProvisionInfo( - pipeline_options=encoded_options)) - runner = fn_runner.FnApiRunner(provision_info=provision_info) - elif _PrismRunnerSupportVisitor().accept(pipeline): + # PrismRunner + if _PrismRunnerSupportVisitor().accept(pipeline): _LOGGER.info('Running pipeline with PrismRunner.') from apache_beam.runners.portability import prism_runner runner = prism_runner.PrismRunner() - tryingPrism = True - else: - runner = BundleBasedDirectRunner() - if tryingPrism: try: pr = runner.run_pipeline(pipeline, options) # This is non-blocking, so if the state is *already* finished, something # probably failed on job submission. if pr.state.is_terminal() and pr.state != PipelineState.DONE: _LOGGER.info( - 'Pipeline failed on PrismRunner, falling back toDirectRunner.') + 'Pipeline failed on PrismRunner, falling back to DirectRunner.') runner = BundleBasedDirectRunner() else: return pr @@ -181,6 +178,19 @@ def visit_transform(self, applied_ptransform): _LOGGER.info('Falling back to DirectRunner') runner = BundleBasedDirectRunner() + # Check whether all transforms used in the pipeline are supported by the + # FnApiRunner, and the pipeline was not meant to be run as streaming. + if _FnApiRunnerSupportVisitor().accept(pipeline): + from apache_beam.portability.api import beam_provision_api_pb2 + from apache_beam.runners.portability.fn_api_runner import fn_runner + from apache_beam.runners.portability.portable_runner import JobServiceHandle + all_options = options.get_all_options() + encoded_options = JobServiceHandle.encode_pipeline_options(all_options) + provision_info = fn_runner.ExtendedProvisionInfo( + beam_provision_api_pb2.ProvisionInfo( + pipeline_options=encoded_options)) + runner = fn_runner.FnApiRunner(provision_info=provision_info) + return runner.run_pipeline(pipeline, options) diff --git a/sdks/python/apache_beam/runners/direct/direct_runner_test.py b/sdks/python/apache_beam/runners/direct/direct_runner_test.py index 1af5f1bc7bea..a14eba851c48 100644 --- a/sdks/python/apache_beam/runners/direct/direct_runner_test.py +++ b/sdks/python/apache_beam/runners/direct/direct_runner_test.py @@ -47,9 +47,10 @@ class DirectPipelineResultTest(unittest.TestCase): def test_waiting_on_result_stops_executor_threads(self): pre_test_threads = set(t.ident for t in threading.enumerate()) - for runner in ['DirectRunner', - 'BundleBasedDirectRunner', - 'SwitchingDirectRunner']: + for runner in [ + 'BundleBasedDirectRunner', + 'apache_beam.runners.portability.fn_api_runner.fn_runner.FnApiRunner' + ]: pipeline = test_pipeline.TestPipeline(runner=runner) _ = (pipeline | beam.Create([{'foo': 'bar'}])) result = pipeline.run()