Skip to content

Commit 78071e2

Browse files
authored
Merge pull request #34583 More eagerly merge compatible environments.
2 parents 7136380 + 01b5bb9 commit 78071e2

File tree

10 files changed

+474
-237
lines changed

10 files changed

+474
-237
lines changed

sdks/python/apache_beam/pipeline.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@
8787
from apache_beam.portability import common_urns
8888
from apache_beam.portability.api import beam_runner_api_pb2
8989
from apache_beam.runners import PipelineRunner
90-
from apache_beam.runners import common
9190
from apache_beam.runners import create_runner
91+
from apache_beam.runners import pipeline_utils
9292
from apache_beam.transforms import ParDo
9393
from apache_beam.transforms import ptransform
9494
from apache_beam.transforms.display import DisplayData
@@ -1019,7 +1019,7 @@ def merge_compatible_environments(proto):
10191019
10201020
Mutates proto as contexts may have references to proto.components.
10211021
"""
1022-
common.merge_common_environments(proto, inplace=True)
1022+
pipeline_utils.merge_common_environments(proto, inplace=True)
10231023

10241024
@staticmethod
10251025
def from_runner_api(

sdks/python/apache_beam/runners/common.py

-162
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@
2222

2323
# pytype: skip-file
2424

25-
import collections
26-
import copy
2725
import logging
2826
import sys
2927
import threading
@@ -42,8 +40,6 @@
4240
from apache_beam.coders import coders
4341
from apache_beam.internal import util
4442
from apache_beam.options.value_provider import RuntimeValueProvider
45-
from apache_beam.portability import common_urns
46-
from apache_beam.portability.api import beam_runner_api_pb2
4743
from apache_beam.pvalue import TaggedOutput
4844
from apache_beam.runners.sdf_utils import NoOpWatermarkEstimatorProvider
4945
from apache_beam.runners.sdf_utils import RestrictionTrackerView
@@ -53,15 +49,13 @@
5349
from apache_beam.runners.sdf_utils import ThreadsafeWatermarkEstimator
5450
from apache_beam.transforms import DoFn
5551
from apache_beam.transforms import core
56-
from apache_beam.transforms import environments
5752
from apache_beam.transforms import userstate
5853
from apache_beam.transforms.core import RestrictionProvider
5954
from apache_beam.transforms.core import WatermarkEstimatorProvider
6055
from apache_beam.transforms.window import GlobalWindow
6156
from apache_beam.transforms.window import GlobalWindows
6257
from apache_beam.transforms.window import TimestampedValue
6358
from apache_beam.transforms.window import WindowFn
64-
from apache_beam.typehints import typehints
6559
from apache_beam.typehints.batch import BatchConverter
6660
from apache_beam.utils.counters import Counter
6761
from apache_beam.utils.counters import CounterName
@@ -1920,159 +1914,3 @@ def windows(self):
19201914
raise AttributeError('windows not accessible in this context')
19211915
else:
19221916
return self.windowed_value.windows
1923-
1924-
1925-
def group_by_key_input_visitor(deterministic_key_coders=True):
1926-
# Importing here to avoid a circular dependency
1927-
# pylint: disable=wrong-import-order, wrong-import-position
1928-
from apache_beam.pipeline import PipelineVisitor
1929-
from apache_beam.transforms.core import GroupByKey
1930-
1931-
class GroupByKeyInputVisitor(PipelineVisitor):
1932-
"""A visitor that replaces `Any` element type for input `PCollection` of
1933-
a `GroupByKey` with a `KV` type.
1934-
1935-
TODO(BEAM-115): Once Python SDK is compatible with the new Runner API,
1936-
we could directly replace the coder instead of mutating the element type.
1937-
"""
1938-
def __init__(self, deterministic_key_coders=True):
1939-
self.deterministic_key_coders = deterministic_key_coders
1940-
1941-
def enter_composite_transform(self, transform_node):
1942-
self.visit_transform(transform_node)
1943-
1944-
def visit_transform(self, transform_node):
1945-
if isinstance(transform_node.transform, GroupByKey):
1946-
pcoll = transform_node.inputs[0]
1947-
pcoll.element_type = typehints.coerce_to_kv_type(
1948-
pcoll.element_type, transform_node.full_label)
1949-
pcoll.requires_deterministic_key_coder = (
1950-
self.deterministic_key_coders and transform_node.full_label)
1951-
key_type, value_type = pcoll.element_type.tuple_types
1952-
if transform_node.outputs:
1953-
key = next(iter(transform_node.outputs.keys()))
1954-
transform_node.outputs[key].element_type = typehints.KV[
1955-
key_type, typehints.Iterable[value_type]]
1956-
transform_node.outputs[key].requires_deterministic_key_coder = (
1957-
self.deterministic_key_coders and transform_node.full_label)
1958-
1959-
return GroupByKeyInputVisitor(deterministic_key_coders)
1960-
1961-
1962-
def validate_pipeline_graph(pipeline_proto):
1963-
"""Ensures this is a correctly constructed Beam pipeline.
1964-
"""
1965-
def get_coder(pcoll_id):
1966-
return pipeline_proto.components.coders[
1967-
pipeline_proto.components.pcollections[pcoll_id].coder_id]
1968-
1969-
def validate_transform(transform_id):
1970-
transform_proto = pipeline_proto.components.transforms[transform_id]
1971-
1972-
# Currently the only validation we perform is that GBK operations have
1973-
# their coders set properly.
1974-
if transform_proto.spec.urn == common_urns.primitives.GROUP_BY_KEY.urn:
1975-
if len(transform_proto.inputs) != 1:
1976-
raise ValueError("Unexpected number of inputs: %s" % transform_proto)
1977-
if len(transform_proto.outputs) != 1:
1978-
raise ValueError("Unexpected number of outputs: %s" % transform_proto)
1979-
input_coder = get_coder(next(iter(transform_proto.inputs.values())))
1980-
output_coder = get_coder(next(iter(transform_proto.outputs.values())))
1981-
if input_coder.spec.urn != common_urns.coders.KV.urn:
1982-
raise ValueError(
1983-
"Bad coder for input of %s: %s" % (transform_id, input_coder))
1984-
if output_coder.spec.urn != common_urns.coders.KV.urn:
1985-
raise ValueError(
1986-
"Bad coder for output of %s: %s" % (transform_id, output_coder))
1987-
output_values_coder = pipeline_proto.components.coders[
1988-
output_coder.component_coder_ids[1]]
1989-
if (input_coder.component_coder_ids[0] !=
1990-
output_coder.component_coder_ids[0] or
1991-
output_values_coder.spec.urn != common_urns.coders.ITERABLE.urn or
1992-
output_values_coder.component_coder_ids[0] !=
1993-
input_coder.component_coder_ids[1]):
1994-
raise ValueError(
1995-
"Incompatible input coder %s and output coder %s for transform %s" %
1996-
(transform_id, input_coder, output_coder))
1997-
elif transform_proto.spec.urn == common_urns.primitives.ASSIGN_WINDOWS.urn:
1998-
if not transform_proto.inputs:
1999-
raise ValueError("Missing input for transform: %s" % transform_proto)
2000-
elif transform_proto.spec.urn == common_urns.primitives.PAR_DO.urn:
2001-
if not transform_proto.inputs:
2002-
raise ValueError("Missing input for transform: %s" % transform_proto)
2003-
2004-
for t in transform_proto.subtransforms:
2005-
validate_transform(t)
2006-
2007-
for t in pipeline_proto.root_transform_ids:
2008-
validate_transform(t)
2009-
2010-
2011-
def merge_common_environments(pipeline_proto, inplace=False):
2012-
def dep_key(dep):
2013-
if dep.type_urn == common_urns.artifact_types.FILE.urn:
2014-
payload = beam_runner_api_pb2.ArtifactFilePayload.FromString(
2015-
dep.type_payload)
2016-
if payload.sha256:
2017-
type_info = 'sha256', payload.sha256
2018-
else:
2019-
type_info = 'path', payload.path
2020-
elif dep.type_urn == common_urns.artifact_types.URL.urn:
2021-
payload = beam_runner_api_pb2.ArtifactUrlPayload.FromString(
2022-
dep.type_payload)
2023-
if payload.sha256:
2024-
type_info = 'sha256', payload.sha256
2025-
else:
2026-
type_info = 'url', payload.url
2027-
else:
2028-
type_info = dep.type_urn, dep.type_payload
2029-
return type_info, dep.role_urn, dep.role_payload
2030-
2031-
def base_env_key(env):
2032-
return (
2033-
env.urn,
2034-
env.payload,
2035-
tuple(sorted(env.capabilities)),
2036-
tuple(sorted(env.resource_hints.items())),
2037-
tuple(sorted(dep_key(dep) for dep in env.dependencies)))
2038-
2039-
def env_key(env):
2040-
return tuple(
2041-
sorted(
2042-
base_env_key(e)
2043-
for e in environments.expand_anyof_environments(env)))
2044-
2045-
canonical_environments = collections.defaultdict(list)
2046-
for env_id, env in pipeline_proto.components.environments.items():
2047-
canonical_environments[env_key(env)].append(env_id)
2048-
2049-
if len(canonical_environments) == len(pipeline_proto.components.environments):
2050-
# All environments are already sufficiently distinct.
2051-
return pipeline_proto
2052-
2053-
environment_remappings = {
2054-
e: es[0]
2055-
for es in canonical_environments.values() for e in es
2056-
}
2057-
2058-
if not inplace:
2059-
pipeline_proto = copy.copy(pipeline_proto)
2060-
2061-
for t in pipeline_proto.components.transforms.values():
2062-
if t.environment_id not in pipeline_proto.components.environments:
2063-
# TODO(https://github.com/apache/beam/issues/30876): Remove this
2064-
# workaround.
2065-
continue
2066-
if t.environment_id:
2067-
t.environment_id = environment_remappings[t.environment_id]
2068-
for w in pipeline_proto.components.windowing_strategies.values():
2069-
if w.environment_id not in pipeline_proto.components.environments:
2070-
# TODO(https://github.com/apache/beam/issues/30876): Remove this
2071-
# workaround.
2072-
continue
2073-
if w.environment_id:
2074-
w.environment_id = environment_remappings[w.environment_id]
2075-
for e in set(pipeline_proto.components.environments.keys()) - set(
2076-
environment_remappings.values()):
2077-
del pipeline_proto.components.environments[e]
2078-
return pipeline_proto

sdks/python/apache_beam/runners/common_test.py

-59
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,8 @@
2626
from apache_beam.io.restriction_trackers import OffsetRestrictionTracker
2727
from apache_beam.io.watermark_estimators import ManualWatermarkEstimator
2828
from apache_beam.options.pipeline_options import PipelineOptions
29-
from apache_beam.portability.api import beam_runner_api_pb2
3029
from apache_beam.runners.common import DoFnSignature
3130
from apache_beam.runners.common import PerWindowInvoker
32-
from apache_beam.runners.common import merge_common_environments
33-
from apache_beam.runners.portability.expansion_service_test import FibTransform
3431
from apache_beam.runners.sdf_utils import SplitResultPrimary
3532
from apache_beam.runners.sdf_utils import SplitResultResidual
3633
from apache_beam.testing.test_pipeline import TestPipeline
@@ -587,61 +584,5 @@ def test_window_observing_split_on_window_boundary_round_down_on_last_window(
587584
self.assertEqual(stop_index, 2)
588585

589586

590-
class UtilitiesTest(unittest.TestCase):
591-
def test_equal_environments_merged(self):
592-
pipeline_proto = merge_common_environments(
593-
beam_runner_api_pb2.Pipeline(
594-
components=beam_runner_api_pb2.Components(
595-
environments={
596-
'a1': beam_runner_api_pb2.Environment(urn='A'),
597-
'a2': beam_runner_api_pb2.Environment(urn='A'),
598-
'b1': beam_runner_api_pb2.Environment(
599-
urn='B', payload=b'x'),
600-
'b2': beam_runner_api_pb2.Environment(
601-
urn='B', payload=b'x'),
602-
'b3': beam_runner_api_pb2.Environment(
603-
urn='B', payload=b'y'),
604-
},
605-
transforms={
606-
't1': beam_runner_api_pb2.PTransform(
607-
unique_name='t1', environment_id='a1'),
608-
't2': beam_runner_api_pb2.PTransform(
609-
unique_name='t2', environment_id='a2'),
610-
},
611-
windowing_strategies={
612-
'w1': beam_runner_api_pb2.WindowingStrategy(
613-
environment_id='b1'),
614-
'w2': beam_runner_api_pb2.WindowingStrategy(
615-
environment_id='b2'),
616-
})))
617-
self.assertEqual(len(pipeline_proto.components.environments), 3)
618-
self.assertTrue(('a1' in pipeline_proto.components.environments)
619-
^ ('a2' in pipeline_proto.components.environments))
620-
self.assertTrue(('b1' in pipeline_proto.components.environments)
621-
^ ('b2' in pipeline_proto.components.environments))
622-
self.assertEqual(
623-
len(
624-
set(
625-
t.environment_id
626-
for t in pipeline_proto.components.transforms.values())),
627-
1)
628-
self.assertEqual(
629-
len(
630-
set(
631-
w.environment_id for w in
632-
pipeline_proto.components.windowing_strategies.values())),
633-
1)
634-
635-
def test_external_merged(self):
636-
p = beam.Pipeline()
637-
# This transform recursively creates several external environments.
638-
_ = p | FibTransform(4)
639-
pipeline_proto = p.to_runner_api()
640-
# All our external environments are equal and consolidated.
641-
# We also have a placeholder "default" environment that has not been
642-
# resolved to anything concrete yet.
643-
self.assertEqual(len(pipeline_proto.components.environments), 2)
644-
645-
646587
if __name__ == '__main__':
647588
unittest.main()

sdks/python/apache_beam/runners/dataflow/dataflow_runner.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,10 @@
4343
from apache_beam.options.pipeline_options import WorkerOptions
4444
from apache_beam.portability import common_urns
4545
from apache_beam.portability.api import beam_runner_api_pb2
46-
from apache_beam.runners.common import group_by_key_input_visitor
47-
from apache_beam.runners.common import merge_common_environments
4846
from apache_beam.runners.dataflow.internal.clients import dataflow as dataflow_api
47+
from apache_beam.runners.pipeline_utils import group_by_key_input_visitor
48+
from apache_beam.runners.pipeline_utils import merge_common_environments
49+
from apache_beam.runners.pipeline_utils import merge_superset_dep_environments
4950
from apache_beam.runners.runner import PipelineResult
5051
from apache_beam.runners.runner import PipelineRunner
5152
from apache_beam.runners.runner import PipelineState
@@ -434,7 +435,8 @@ def run_pipeline(self, pipeline, options, pipeline_proto=None):
434435
self.proto_pipeline.components.environments[env_id].CopyFrom(
435436
environments.resolve_anyof_environment(
436437
env, common_urns.environments.DOCKER.urn))
437-
self.proto_pipeline = merge_common_environments(self.proto_pipeline)
438+
self.proto_pipeline = merge_common_environments(
439+
merge_superset_dep_environments(self.proto_pipeline))
438440

439441
# Optimize the pipeline if it not streaming and the pre_optimize
440442
# experiment is set.

sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@
3535
from apache_beam.pvalue import PCollection
3636
from apache_beam.runners import DataflowRunner
3737
from apache_beam.runners import TestDataflowRunner
38-
from apache_beam.runners import common
3938
from apache_beam.runners import create_runner
39+
from apache_beam.runners import pipeline_utils
4040
from apache_beam.runners.dataflow.dataflow_runner import DataflowPipelineResult
4141
from apache_beam.runners.dataflow.dataflow_runner import DataflowRuntimeException
4242
from apache_beam.runners.dataflow.dataflow_runner import _check_and_add_missing_options
@@ -316,7 +316,7 @@ def test_group_by_key_input_visitor_with_valid_inputs(self):
316316
applied = AppliedPTransform(
317317
None, beam.GroupByKey(), "label", {'pcoll': pcoll}, None, None)
318318
applied.outputs[None] = PCollection(None)
319-
common.group_by_key_input_visitor().visit_transform(applied)
319+
pipeline_utils.group_by_key_input_visitor().visit_transform(applied)
320320
self.assertEqual(
321321
pcoll.element_type, typehints.KV[typehints.Any, typehints.Any])
322322

@@ -332,7 +332,7 @@ def test_group_by_key_input_visitor_with_invalid_inputs(self):
332332
"Found .*")
333333
for pcoll in [pcoll1, pcoll2]:
334334
with self.assertRaisesRegex(ValueError, err_msg):
335-
common.group_by_key_input_visitor().visit_transform(
335+
pipeline_utils.group_by_key_input_visitor().visit_transform(
336336
AppliedPTransform(
337337
None, beam.GroupByKey(), "label", {'in': pcoll}, None, None))
338338

@@ -341,7 +341,7 @@ def test_group_by_key_input_visitor_for_non_gbk_transforms(self):
341341
pcoll = PCollection(p)
342342
for transform in [beam.Flatten(), beam.Map(lambda x: x)]:
343343
pcoll.element_type = typehints.Any
344-
common.group_by_key_input_visitor().visit_transform(
344+
pipeline_utils.group_by_key_input_visitor().visit_transform(
345345
AppliedPTransform(
346346
None, transform, "label", {'in': pcoll}, None, None))
347347
self.assertEqual(pcoll.element_type, typehints.Any)
@@ -383,7 +383,7 @@ def test_gbk_then_flatten_input_visitor(self):
383383
# to make sure the check below is not vacuous.
384384
self.assertNotIsInstance(flat.element_type, typehints.TupleConstraint)
385385

386-
p.visit(common.group_by_key_input_visitor())
386+
p.visit(pipeline_utils.group_by_key_input_visitor())
387387
p.visit(DataflowRunner.flatten_input_visitor())
388388

389389
# The dataflow runner requires gbk input to be tuples *and* flatten

sdks/python/apache_beam/runners/dataflow/internal/apiclient.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,10 @@
6363
from apache_beam.options.pipeline_options import WorkerOptions
6464
from apache_beam.portability import common_urns
6565
from apache_beam.portability.api import beam_runner_api_pb2
66-
from apache_beam.runners.common import validate_pipeline_graph
6766
from apache_beam.runners.dataflow.internal import names
6867
from apache_beam.runners.dataflow.internal.clients import dataflow
6968
from apache_beam.runners.internal import names as shared_names
69+
from apache_beam.runners.pipeline_utils import validate_pipeline_graph
7070
from apache_beam.runners.portability.stager import Stager
7171
from apache_beam.transforms import DataflowDistributionCounter
7272
from apache_beam.transforms import cy_combiners

0 commit comments

Comments
 (0)