3 changes: 3 additions & 0 deletions sdks/go/pkg/beam/runners/prism/internal/environments.go
@@ -163,6 +163,9 @@ func externalEnvironment(ctx context.Context, ep *pipepb.ExternalPayload, wk *wo
pool.StopWorker(bgContext, &fnpb.StopWorkerRequest{
WorkerId: wk.ID,
})
// Allow a brief grace period for the SDK worker to cleanly shut down its client gRPC channels
// before tearing down the server-side gRPC streams.
time.Sleep(1 * time.Second)
wk.Stop()
}

13 changes: 13 additions & 0 deletions sdks/python/apache_beam/runners/job/utils.py
@@ -23,6 +23,7 @@
import json
import logging

import grpc
from google.protobuf import json_format
from google.protobuf import struct_pb2

@@ -37,3 +38,15 @@ def dict_to_struct(dict_obj: dict) -> struct_pb2.Struct:

def struct_to_dict(struct_obj: struct_pb2.Struct) -> dict:
return json.loads(json_format.MessageToJson(struct_obj))


def is_grpc_stream_closure_error(e, allow_deadline_exceeded=False):
"""Check whether a gRPC exception represents an expected stream termination
by the runner during job shutdown, cancellation, or timeout.
"""
if not isinstance(e, grpc.RpcError) or not hasattr(e, 'code'):
return False
expected_codes = {grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.CANCELLED}
if allow_deadline_exceeded:
expected_codes.add(grpc.StatusCode.DEADLINE_EXCEEDED)
return e.code() in expected_codes
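
For illustration, a minimal usage sketch of this helper in a stream-consuming loop; consume is a hypothetical function (not part of this PR), and stream stands for any gRPC response iterator:

import logging

import grpc

from apache_beam.runners.job import utils as job_utils


def consume(stream):
  # Drain a gRPC response stream, treating expected closures as clean shutdown.
  try:
    for msg in stream:
      logging.debug('received: %s', msg)
  except grpc.RpcError as e:
    if job_utils.is_grpc_stream_closure_error(e, allow_deadline_exceeded=True):
      # Expected termination during shutdown, cancellation, or timeout.
      logging.info('Stream closed by runner: %s', e)
      return
    # Anything else is a genuine failure and should propagate.
    raise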
33 changes: 20 additions & 13 deletions sdks/python/apache_beam/runners/portability/portable_runner.py
@@ -534,20 +534,27 @@ def wait_until_finish(self, duration=None):
def read_messages() -> None:
nonlocal last_error_text
previous_state = -1
for message in self._message_stream:
if message.HasField('message_response'):
mr = message.message_response
logging.log(MESSAGE_LOG_LEVELS[mr.importance], "%s", mr.message_text)
if mr.importance == beam_job_api_pb2.JobMessage.JOB_MESSAGE_ERROR:
last_error_text = mr.message_text
try:
for message in self._message_stream:
if message.HasField('message_response'):
mr = message.message_response
logging.log(
MESSAGE_LOG_LEVELS[mr.importance], "%s", mr.message_text)
if mr.importance == beam_job_api_pb2.JobMessage.JOB_MESSAGE_ERROR:
last_error_text = mr.message_text
else:
current_state = message.state_response.state
if current_state != previous_state:
_LOGGER.info(
"Job state changed to %s",
self.runner_api_state_to_pipeline_state(current_state))
previous_state = current_state
self._messages.append(message)
except grpc.RpcError as e:
if job_utils.is_grpc_stream_closure_error(e):
_LOGGER.info('Job message stream closed by runner: %s', e)
else:
current_state = message.state_response.state
if current_state != previous_state:
_LOGGER.info(
"Job state changed to %s",
self.runner_api_state_to_pipeline_state(current_state))
previous_state = current_state
self._messages.append(message)
raise

message_thread = threading.Thread(
target=read_messages, name='wait_until_finish_read')
87 changes: 86 additions & 1 deletion sdks/python/apache_beam/runners/portability/prism_runner_test.py
@@ -21,7 +21,6 @@
import os.path
import queue
import shlex
import threading
import time
import typing
import unittest
@@ -300,6 +299,92 @@ def test_after_count_trigger_streaming(self):
('B-3', {10, 15, 16}),
])))

def test_dofn_failure_clean_exit(self):
class FailDoFn(beam.DoFn):
def process(self, element):
raise ValueError("Failing as intended")

class BlockDoFn(beam.DoFn):
def process(self, element):
time.sleep(1000)
yield element

with self.assertRaisesRegex(Exception, "Failing as intended"):
with self.create_pipeline() as p:
imp = p | beam.Create([1, 2])
# Ensure the steps are not fused (otherwise siblings are run sequentially
# in a single thread, making execution order dependent on internal map
# traversal). Reshuffle acts as a fusion break so they run in parallel.
_ = imp | 'ReshuffleBlock' >> beam.Reshuffle() | 'Block' >> beam.ParDo(
BlockDoFn())
_ = imp | 'ReshuffleFail' >> beam.Reshuffle() | 'Fail' >> beam.ParDo(
FailDoFn())

def test_dofn_failure_delayed_worker_shutdown(self):
"""Simulates a scenario where a DoFn failure causes pipeline abortion, but the Python
SDK worker takes a longer time to shut down before Prism closes the gRPC channel.
Verifies both Option 1 (graceful handling of RpcError) and Option 2 (Prism runner grace period).
"""
from apache_beam.runners.worker.data_plane import GrpcClientDataChannel

orig_close = GrpcClientDataChannel.close

def delayed_close(self_channel):
time.sleep(0.5)
orig_close(self_channel)

class FailDoFn(beam.DoFn):
def process(self, element):
raise ValueError("Delayed shutdown fail as intended")

class BlockDoFn(beam.DoFn):
def process(self, element):
time.sleep(1000)
yield element

with mock.patch.object(GrpcClientDataChannel, 'close', new=delayed_close):
with self.assertLogs('apache_beam.runners.worker.data_plane',
level='DEBUG') as log_cm:
with self.assertRaisesRegex(Exception,
"Delayed shutdown fail as intended"):
with self.create_pipeline() as p:
imp = p | beam.Create([1, 2])
_ = imp | 'ReshuffleBlock' >> beam.Reshuffle(
) | 'Block' >> beam.ParDo(BlockDoFn())
_ = imp | 'ReshuffleFail' >> beam.Reshuffle(
) | 'Fail' >> beam.ParDo(FailDoFn())

# Ensure no ERROR logs were emitted by data_plane during the delayed shutdown
self.assertFalse(
any(
"Failed to read inputs in the data plane." in log
for log in log_cm.output))

def test_dofn_deadline_exceeded(self):
"""Simulates a scenario where a pipeline running on Prism exceeds its configured
deadline, triggering DEADLINE_EXCEEDED on the job message, data, and control streams.
Verifies that all stream closures are handled cleanly without unhandled thread crashes.
"""
from apache_beam.runners.portability.portable_runner import JobServiceHandle

orig_init = JobServiceHandle.__init__

def custom_init(
self_handle, job_service, options, retain_unknown_options=False):
orig_init(self_handle, job_service, options, retain_unknown_options)
self_handle.timeout = 2

class BlockDoFn(beam.DoFn):
def process(self, element):
time.sleep(1000)
yield element

with mock.patch.object(JobServiceHandle, '__init__', new=custom_init):
with self.assertRaisesRegex(Exception, "Deadline Exceeded"):
with self.create_pipeline() as p:
_ = p | beam.Create([1, 2]) | beam.Reshuffle() | beam.ParDo(
BlockDoFn())


class PrismJobServerTest(unittest.TestCase):
def setUp(self) -> None:
10 changes: 7 additions & 3 deletions sdks/python/apache_beam/runners/worker/data_plane.py
@@ -47,6 +47,7 @@
from apache_beam.coders import coder_impl
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_fn_api_pb2_grpc
from apache_beam.runners.job import utils as job_utils
from apache_beam.runners.worker.channel_factory import GRPCChannelFactory
from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor
from apache_beam.utils.byte_limited_queue import ByteLimitedQueue
@@ -723,9 +724,12 @@ def _put_queue(instruction_id, element):
_put_queue(data.instruction_id, data)
except Exception as e:
if not self._closed:
_LOGGER.exception('Failed to read inputs in the data plane.')
self._exception = e
raise
if job_utils.is_grpc_stream_closure_error(e):
_LOGGER.info('Data plane stream closed by runner: %s', e)
else:
_LOGGER.exception('Failed to read inputs in the data plane.')
self._exception = e
raise
finally:
self._closed = True
self._reads_finished.set()
6 changes: 6 additions & 0 deletions sdks/python/apache_beam/runners/worker/sdk_worker.py
@@ -52,6 +52,7 @@
from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_fn_api_pb2_grpc
from apache_beam.portability.api import metrics_pb2
from apache_beam.runners.job import utils as job_utils
from apache_beam.runners.worker import bundle_processor
from apache_beam.runners.worker import data_plane
from apache_beam.runners.worker import statesampler
@@ -279,6 +280,11 @@ def get_responses():
# will be like self.request_register(request)
getattr(self, SdkHarness.REQUEST_METHOD_PREFIX + request_type)(
work_request)
except grpc.RpcError as e:
if job_utils.is_grpc_stream_closure_error(e):
_LOGGER.info('Control plane stream closed by runner: %s', e)
else:
raise
finally:
self._alive = False
if self.data_sampler:
23 changes: 20 additions & 3 deletions sdks/python/apache_beam/utils/multi_process_shared.py
@@ -370,6 +370,7 @@ def __init__(
self._cross_process_lock = fasteners.InterProcessLock(
os.path.join(self._path, self._tag) + '.lock')
self._spawn_process = spawn_process
self._server_process = None

def _get_manager(self):
if self._manager is None:
@@ -397,7 +398,7 @@ def _get_manager(self):
retryable_exceptions = (ConnectionError, EOFError)

@retry.with_exponential_backoff(
num_retries=5,
num_retries=7,
initial_delay_secs=0.1,
retry_filter=lambda exn: isinstance(
exn, retryable_exceptions))
@@ -408,8 +409,23 @@ def connect_manager():
connect_manager()
self._manager = manager
except retryable_exceptions:
# The server is no longer good, assume it died.
os.unlink(address_file)
# The server is no longer good; assume it died and terminate it if we spawned it.
if getattr(self, '_life_line', None):
try:
self._life_line.close()
except Exception:
pass
if getattr(self, '_server_process', None) and self._server_process.is_alive():
logging.info(
"Terminating unresponsive server process %s",
self._server_process.pid)
try:
self._server_process.kill()
self._server_process.join(timeout=1.0)
except Exception:
pass
if os.path.exists(address_file):
os.unlink(address_file)

return self._manager

@@ -463,6 +479,7 @@ def _create_server(self, address_file):
daemon=False # Must be False for nested proxies
)
p.start()
self._server_process = p
logging.info("Parent: Waiting for %s to write address file...", self._tag)

def cleanup_process():
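
For illustration, a self-contained sketch of the kill-before-respawn recovery added above; _unresponsive_server, kill_and_reset, and the address-file path are hypothetical stand-ins, not this module's real API:

import logging
import multiprocessing
import os
import time


def _unresponsive_server():
  # Stand-in for a spawned proxy server that hangs before serving.
  time.sleep(3600)


def kill_and_reset(proc, address_file):
  # Force-kill a server that never became reachable. It never served anything,
  # so there is no clean shutdown to attempt, and leaving it alive risks the
  # deadlock described in test_connection_timeout_respawn_deadlock: its monitor
  # could later delete the replacement server's address file.
  if proc is not None and proc.is_alive():
    logging.info('Terminating unresponsive server process %s', proc.pid)
    proc.kill()
    proc.join(timeout=1.0)  # reap the child so it cannot act afterwards
  if os.path.exists(address_file):
    os.unlink(address_file)  # let the next spawn publish a fresh address


if __name__ == '__main__':
  p = multiprocessing.Process(target=_unresponsive_server)
  p.start()
  kill_and_reset(p, '/tmp/example.address')  # hypothetical address file path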
39 changes: 38 additions & 1 deletion sdks/python/apache_beam/utils/multi_process_shared_test.py
@@ -295,7 +295,8 @@ def setUp(self):
'mix2',
'test_process_exit',
'thundering_herd_test',
'transient_test']:
'transient_test',
'timeout_deadlock_test']:
for ext in ['', '.address', '.address.error']:
try:
os.remove(os.path.join(tempdir, tag + ext))
@@ -489,6 +490,42 @@ def side_effect_connect(self_mgr, *args, **kwargs):
self.assertEqual(counter1.increment(), 1)
self.assertEqual(counter2.increment(), 2)

def test_connection_timeout_respawn_deadlock(self):
# Tests that when connection to a newly spawned proxy server times out after
# all retries, the parent process cleanly terminates the unresponsive server
# before spawning a replacement. Without this termination, the parent unlinks
# the address file and spawns Server 2, discarding the IPC pipe to Server 1.
# When Server 1 detects the closed pipe, its suicide-pact monitor wakes up and
# deletes Server 2's address file, leaving the parent hanging forever in
# _create_server waiting for the address file to appear.
shared = multi_process_shared.MultiProcessShared(
Counter, tag='timeout_deadlock_test', always_proxy=True, spawn_process=True)

orig_connect = multi_process_shared._SingletonRegistrar.connect
connect_calls = [0]

def side_effect_connect(self_mgr, *args, **kwargs):
connect_calls[0] += 1
if connect_calls[0] <= 10:
raise ConnectionError("Simulated proxy server connection timeout")
return orig_connect(self_mgr, *args, **kwargs)

with patch.object(
multi_process_shared._SingletonRegistrar,
'connect',
autospec=True,
side_effect=side_effect_connect):
res = []
def run_acquire():
res.append(shared.acquire())

t = threading.Thread(target=run_acquire)
t.start()
t.join(timeout=20.0)
if t.is_alive():
raise TimeoutError("Timeout waiting for proxy server to respawn")

self.assertEqual(res[0].increment(), 1)


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)