
More eagerly merge compatible environments. #34583

Merged: 3 commits, Apr 9, 2025
Changes from 2 commits
4 changes: 2 additions & 2 deletions sdks/python/apache_beam/pipeline.py
@@ -87,8 +87,8 @@
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners import PipelineRunner
from apache_beam.runners import common
from apache_beam.runners import create_runner
from apache_beam.runners import pipeline_utils
from apache_beam.transforms import ParDo
from apache_beam.transforms import ptransform
from apache_beam.transforms.display import DisplayData
@@ -1019,7 +1019,7 @@ def merge_compatible_environments(proto):

Mutates proto as contexts may have references to proto.components.
"""
common.merge_common_environments(proto, inplace=True)
pipeline_utils.merge_common_environments(proto, inplace=True)

@staticmethod
def from_runner_api(
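
To make the relocation concrete, here is a minimal sketch (not part of this PR) of invoking the moved helper directly on a pipeline proto. The toy pipeline and variable names are illustrative; the merge_common_environments(proto, inplace=True) call mirrors the delegation shown above.

import apache_beam as beam
from apache_beam.runners import pipeline_utils

p = beam.Pipeline()
_ = p | beam.Create([1, 2, 3]) | beam.Map(lambda x: x + 1)

# Serialize the pipeline to its portable proto form.
proto = p.to_runner_api()

# Deduplicate equivalent environments in place; transforms and windowing
# strategies that referenced a removed duplicate are re-pointed at the
# surviving environment.
pipeline_utils.merge_common_environments(proto, inplace=True)
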
162 changes: 0 additions & 162 deletions sdks/python/apache_beam/runners/common.py
@@ -22,8 +22,6 @@

# pytype: skip-file

import collections
import copy
import logging
import sys
import threading
@@ -42,8 +40,6 @@
from apache_beam.coders import coders
from apache_beam.internal import util
from apache_beam.options.value_provider import RuntimeValueProvider
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.pvalue import TaggedOutput
from apache_beam.runners.sdf_utils import NoOpWatermarkEstimatorProvider
from apache_beam.runners.sdf_utils import RestrictionTrackerView
@@ -53,15 +49,13 @@
from apache_beam.runners.sdf_utils import ThreadsafeWatermarkEstimator
from apache_beam.transforms import DoFn
from apache_beam.transforms import core
from apache_beam.transforms import environments
from apache_beam.transforms import userstate
from apache_beam.transforms.core import RestrictionProvider
from apache_beam.transforms.core import WatermarkEstimatorProvider
from apache_beam.transforms.window import GlobalWindow
from apache_beam.transforms.window import GlobalWindows
from apache_beam.transforms.window import TimestampedValue
from apache_beam.transforms.window import WindowFn
from apache_beam.typehints import typehints
from apache_beam.typehints.batch import BatchConverter
from apache_beam.utils.counters import Counter
from apache_beam.utils.counters import CounterName
@@ -1920,159 +1914,3 @@ def windows(self):
raise AttributeError('windows not accessible in this context')
else:
return self.windowed_value.windows


def group_by_key_input_visitor(deterministic_key_coders=True):
# Importing here to avoid a circular dependency
# pylint: disable=wrong-import-order, wrong-import-position
from apache_beam.pipeline import PipelineVisitor
from apache_beam.transforms.core import GroupByKey

class GroupByKeyInputVisitor(PipelineVisitor):
"""A visitor that replaces `Any` element type for input `PCollection` of
a `GroupByKey` with a `KV` type.

TODO(BEAM-115): Once Python SDK is compatible with the new Runner API,
we could directly replace the coder instead of mutating the element type.
"""
def __init__(self, deterministic_key_coders=True):
self.deterministic_key_coders = deterministic_key_coders

def enter_composite_transform(self, transform_node):
self.visit_transform(transform_node)

def visit_transform(self, transform_node):
if isinstance(transform_node.transform, GroupByKey):
pcoll = transform_node.inputs[0]
pcoll.element_type = typehints.coerce_to_kv_type(
pcoll.element_type, transform_node.full_label)
pcoll.requires_deterministic_key_coder = (
self.deterministic_key_coders and transform_node.full_label)
key_type, value_type = pcoll.element_type.tuple_types
if transform_node.outputs:
key = next(iter(transform_node.outputs.keys()))
transform_node.outputs[key].element_type = typehints.KV[
key_type, typehints.Iterable[value_type]]
transform_node.outputs[key].requires_deterministic_key_coder = (
self.deterministic_key_coders and transform_node.full_label)

return GroupByKeyInputVisitor(deterministic_key_coders)


def validate_pipeline_graph(pipeline_proto):
"""Ensures this is a correctly constructed Beam pipeline.
"""
def get_coder(pcoll_id):
return pipeline_proto.components.coders[
pipeline_proto.components.pcollections[pcoll_id].coder_id]

def validate_transform(transform_id):
transform_proto = pipeline_proto.components.transforms[transform_id]

# Currently the only validation we perform is that GBK operations have
# their coders set properly.
if transform_proto.spec.urn == common_urns.primitives.GROUP_BY_KEY.urn:
if len(transform_proto.inputs) != 1:
raise ValueError("Unexpected number of inputs: %s" % transform_proto)
if len(transform_proto.outputs) != 1:
raise ValueError("Unexpected number of outputs: %s" % transform_proto)
input_coder = get_coder(next(iter(transform_proto.inputs.values())))
output_coder = get_coder(next(iter(transform_proto.outputs.values())))
if input_coder.spec.urn != common_urns.coders.KV.urn:
raise ValueError(
"Bad coder for input of %s: %s" % (transform_id, input_coder))
if output_coder.spec.urn != common_urns.coders.KV.urn:
raise ValueError(
"Bad coder for output of %s: %s" % (transform_id, output_coder))
output_values_coder = pipeline_proto.components.coders[
output_coder.component_coder_ids[1]]
if (input_coder.component_coder_ids[0] !=
output_coder.component_coder_ids[0] or
output_values_coder.spec.urn != common_urns.coders.ITERABLE.urn or
output_values_coder.component_coder_ids[0] !=
input_coder.component_coder_ids[1]):
raise ValueError(
"Incompatible input coder %s and output coder %s for transform %s" %
(transform_id, input_coder, output_coder))
elif transform_proto.spec.urn == common_urns.primitives.ASSIGN_WINDOWS.urn:
if not transform_proto.inputs:
raise ValueError("Missing input for transform: %s" % transform_proto)
elif transform_proto.spec.urn == common_urns.primitives.PAR_DO.urn:
if not transform_proto.inputs:
raise ValueError("Missing input for transform: %s" % transform_proto)

for t in transform_proto.subtransforms:
validate_transform(t)

for t in pipeline_proto.root_transform_ids:
validate_transform(t)


def merge_common_environments(pipeline_proto, inplace=False):
def dep_key(dep):
if dep.type_urn == common_urns.artifact_types.FILE.urn:
payload = beam_runner_api_pb2.ArtifactFilePayload.FromString(
dep.type_payload)
if payload.sha256:
type_info = 'sha256', payload.sha256
else:
type_info = 'path', payload.path
elif dep.type_urn == common_urns.artifact_types.URL.urn:
payload = beam_runner_api_pb2.ArtifactUrlPayload.FromString(
dep.type_payload)
if payload.sha256:
type_info = 'sha256', payload.sha256
else:
type_info = 'url', payload.url
else:
type_info = dep.type_urn, dep.type_payload
return type_info, dep.role_urn, dep.role_payload

def base_env_key(env):
return (
env.urn,
env.payload,
tuple(sorted(env.capabilities)),
tuple(sorted(env.resource_hints.items())),
tuple(sorted(dep_key(dep) for dep in env.dependencies)))

def env_key(env):
return tuple(
sorted(
base_env_key(e)
for e in environments.expand_anyof_environments(env)))

canonical_environments = collections.defaultdict(list)
for env_id, env in pipeline_proto.components.environments.items():
canonical_environments[env_key(env)].append(env_id)

if len(canonical_environments) == len(pipeline_proto.components.environments):
# All environments are already sufficiently distinct.
return pipeline_proto

environment_remappings = {
e: es[0]
for es in canonical_environments.values() for e in es
}

if not inplace:
pipeline_proto = copy.copy(pipeline_proto)

for t in pipeline_proto.components.transforms.values():
if t.environment_id not in pipeline_proto.components.environments:
# TODO(https://github.com/apache/beam/issues/30876): Remove this
# workaround.
continue
if t.environment_id:
t.environment_id = environment_remappings[t.environment_id]
for w in pipeline_proto.components.windowing_strategies.values():
if w.environment_id not in pipeline_proto.components.environments:
# TODO(https://github.com/apache/beam/issues/30876): Remove this
# workaround.
continue
if w.environment_id:
w.environment_id = environment_remappings[w.environment_id]
for e in set(pipeline_proto.components.environments.keys()) - set(
environment_remappings.values()):
del pipeline_proto.components.environments[e]
return pipeline_proto
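
The helpers deleted here (group_by_key_input_visitor, validate_pipeline_graph, merge_common_environments) are relocated to apache_beam.runners.pipeline_utils, as the import changes in the other files show. A rough usage sketch of the graph validator under that assumption follows; the toy pipeline is illustrative only.

import apache_beam as beam
from apache_beam.runners import pipeline_utils

p = beam.Pipeline()
_ = (
    p
    | beam.Create([('a', 1), ('a', 2), ('b', 3)])
    | beam.GroupByKey())
proto = p.to_runner_api()

# Raises ValueError if, for example, a GroupByKey input or output coder is
# not a KV coder, or if a ParDo/AssignWindows transform has no input.
pipeline_utils.validate_pipeline_graph(proto)
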
59 changes: 0 additions & 59 deletions sdks/python/apache_beam/runners/common_test.py
@@ -26,11 +26,8 @@
from apache_beam.io.restriction_trackers import OffsetRestrictionTracker
from apache_beam.io.watermark_estimators import ManualWatermarkEstimator
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.common import DoFnSignature
from apache_beam.runners.common import PerWindowInvoker
from apache_beam.runners.common import merge_common_environments
from apache_beam.runners.portability.expansion_service_test import FibTransform
from apache_beam.runners.sdf_utils import SplitResultPrimary
from apache_beam.runners.sdf_utils import SplitResultResidual
from apache_beam.testing.test_pipeline import TestPipeline
@@ -587,61 +584,5 @@ def test_window_observing_split_on_window_boundary_round_down_on_last_window(
self.assertEqual(stop_index, 2)


class UtilitiesTest(unittest.TestCase):
def test_equal_environments_merged(self):
pipeline_proto = merge_common_environments(
beam_runner_api_pb2.Pipeline(
components=beam_runner_api_pb2.Components(
environments={
'a1': beam_runner_api_pb2.Environment(urn='A'),
'a2': beam_runner_api_pb2.Environment(urn='A'),
'b1': beam_runner_api_pb2.Environment(
urn='B', payload=b'x'),
'b2': beam_runner_api_pb2.Environment(
urn='B', payload=b'x'),
'b3': beam_runner_api_pb2.Environment(
urn='B', payload=b'y'),
},
transforms={
't1': beam_runner_api_pb2.PTransform(
unique_name='t1', environment_id='a1'),
't2': beam_runner_api_pb2.PTransform(
unique_name='t2', environment_id='a2'),
},
windowing_strategies={
'w1': beam_runner_api_pb2.WindowingStrategy(
environment_id='b1'),
'w2': beam_runner_api_pb2.WindowingStrategy(
environment_id='b2'),
})))
self.assertEqual(len(pipeline_proto.components.environments), 3)
self.assertTrue(('a1' in pipeline_proto.components.environments)
^ ('a2' in pipeline_proto.components.environments))
self.assertTrue(('b1' in pipeline_proto.components.environments)
^ ('b2' in pipeline_proto.components.environments))
self.assertEqual(
len(
set(
t.environment_id
for t in pipeline_proto.components.transforms.values())),
1)
self.assertEqual(
len(
set(
w.environment_id for w in
pipeline_proto.components.windowing_strategies.values())),
1)

def test_external_merged(self):
p = beam.Pipeline()
# This transform recursively creates several external environments.
_ = p | FibTransform(4)
pipeline_proto = p.to_runner_api()
# All our external environments are equal and consolidated.
# We also have a placeholder "default" environment that has not been
# resolved to anything concrete yet.
self.assertEqual(len(pipeline_proto.components.environments), 2)


if __name__ == '__main__':
unittest.main()
8 changes: 5 additions & 3 deletions sdks/python/apache_beam/runners/dataflow/dataflow_runner.py
@@ -43,9 +43,10 @@
from apache_beam.options.pipeline_options import WorkerOptions
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.common import group_by_key_input_visitor
from apache_beam.runners.common import merge_common_environments
from apache_beam.runners.dataflow.internal.clients import dataflow as dataflow_api
from apache_beam.runners.pipeline_utils import group_by_key_input_visitor
from apache_beam.runners.pipeline_utils import merge_common_environments
from apache_beam.runners.pipeline_utils import merge_superset_dep_environments
from apache_beam.runners.runner import PipelineResult
from apache_beam.runners.runner import PipelineRunner
from apache_beam.runners.runner import PipelineState
@@ -434,7 +435,8 @@ def run_pipeline(self, pipeline, options, pipeline_proto=None):
self.proto_pipeline.components.environments[env_id].CopyFrom(
environments.resolve_anyof_environment(
env, common_urns.environments.DOCKER.urn))
self.proto_pipeline = merge_common_environments(self.proto_pipeline)
self.proto_pipeline = merge_common_environments(
merge_superset_dep_environments(self.proto_pipeline))

# Optimize the pipeline if it not streaming and the pre_optimize
# experiment is set.
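
The runner now chains two consolidation passes before submission. A condensed sketch of that pattern follows; the wrapper function is illustrative, and the behavior attributed to merge_superset_dep_environments (collapsing environments whose dependency set is contained in an otherwise-compatible environment's) is inferred from its name rather than from code shown in this diff.

from apache_beam.runners.pipeline_utils import merge_common_environments
from apache_beam.runners.pipeline_utils import merge_superset_dep_environments


def consolidate_environments(pipeline_proto):
  """Illustrative wrapper mirroring the call added to run_pipeline above."""
  # Presumed order: first fold environments subsumed by a dependency-superset
  # environment, then deduplicate environments that are exactly equivalent.
  return merge_common_environments(
      merge_superset_dep_environments(pipeline_proto))
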
10 changes: 5 additions & 5 deletions sdks/python/apache_beam/runners/dataflow/dataflow_runner_test.py
@@ -35,8 +35,8 @@
from apache_beam.pvalue import PCollection
from apache_beam.runners import DataflowRunner
from apache_beam.runners import TestDataflowRunner
from apache_beam.runners import common
from apache_beam.runners import create_runner
from apache_beam.runners import pipeline_utils
from apache_beam.runners.dataflow.dataflow_runner import DataflowPipelineResult
from apache_beam.runners.dataflow.dataflow_runner import DataflowRuntimeException
from apache_beam.runners.dataflow.dataflow_runner import _check_and_add_missing_options
@@ -316,7 +316,7 @@ def test_group_by_key_input_visitor_with_valid_inputs(self):
applied = AppliedPTransform(
None, beam.GroupByKey(), "label", {'pcoll': pcoll}, None, None)
applied.outputs[None] = PCollection(None)
common.group_by_key_input_visitor().visit_transform(applied)
pipeline_utils.group_by_key_input_visitor().visit_transform(applied)
self.assertEqual(
pcoll.element_type, typehints.KV[typehints.Any, typehints.Any])

@@ -332,7 +332,7 @@ def test_group_by_key_input_visitor_with_invalid_inputs(self):
"Found .*")
for pcoll in [pcoll1, pcoll2]:
with self.assertRaisesRegex(ValueError, err_msg):
common.group_by_key_input_visitor().visit_transform(
pipeline_utils.group_by_key_input_visitor().visit_transform(
AppliedPTransform(
None, beam.GroupByKey(), "label", {'in': pcoll}, None, None))

@@ -341,7 +341,7 @@ def test_group_by_key_input_visitor_for_non_gbk_transforms(self):
pcoll = PCollection(p)
for transform in [beam.Flatten(), beam.Map(lambda x: x)]:
pcoll.element_type = typehints.Any
common.group_by_key_input_visitor().visit_transform(
pipeline_utils.group_by_key_input_visitor().visit_transform(
AppliedPTransform(
None, transform, "label", {'in': pcoll}, None, None))
self.assertEqual(pcoll.element_type, typehints.Any)
@@ -383,7 +383,7 @@ def test_gbk_then_flatten_input_visitor(self):
# to make sure the check below is not vacuous.
self.assertNotIsInstance(flat.element_type, typehints.TupleConstraint)

p.visit(common.group_by_key_input_visitor())
p.visit(pipeline_utils.group_by_key_input_visitor())
p.visit(DataflowRunner.flatten_input_visitor())

# The dataflow runner requires gbk input to be tuples *and* flatten
@@ -63,10 +63,10 @@
from apache_beam.options.pipeline_options import WorkerOptions
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.common import validate_pipeline_graph
from apache_beam.runners.dataflow.internal import names
from apache_beam.runners.dataflow.internal.clients import dataflow
from apache_beam.runners.internal import names as shared_names
from apache_beam.runners.pipeline_utils import validate_pipeline_graph
from apache_beam.runners.portability.stager import Stager
from apache_beam.transforms import DataflowDistributionCounter
from apache_beam.transforms import cy_combiners