Skip to content

Commit 3975e89

Browse files
committed
Add an experiment flag to make prism job server a singleton.
1 parent bd2891d commit 3975e89

File tree

2 files changed

+73
-1
lines changed

2 files changed

+73
-1
lines changed

Diff for: sdks/python/apache_beam/runners/portability/prism_runner.py

+32-1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import shutil
2929
import stat
3030
import subprocess
31+
import threading
3132
import typing
3233
import urllib
3334
import zipfile
@@ -56,6 +57,20 @@ class PrismRunner(portable_runner.PortableRunner):
5657
"""A runner for launching jobs on Prism, automatically downloading and
5758
starting a Prism instance if needed.
5859
"""
60+
__singleton = None
61+
__singleton_lock = threading.Lock()
62+
63+
@staticmethod
64+
def get_job_server(options):
65+
debug_options = options.view_as(pipeline_options.DebugOptions)
66+
if debug_options.lookup_experiment("enable_prism_server_singleton"):
67+
with PrismRunner.__singleton_lock:
68+
if PrismRunner.__singleton is None:
69+
PrismRunner.__singleton = PrismJobServer(options)
70+
return PrismRunner.__singleton
71+
72+
return PrismJobServer(options)
73+
5974
def default_environment(
6075
self,
6176
options: pipeline_options.PipelineOptions) -> environments.Environment:
@@ -66,7 +81,7 @@ def default_environment(
6681
return super().default_environment(options)
6782

6883
def default_job_server(self, options):
69-
return job_server.StopOnExitJobServer(PrismJobServer(options))
84+
return job_server.StopOnExitJobServer(PrismRunner.get_job_server(options))
7085

7186
def create_job_service_handle(self, job_service, options):
7287
return portable_runner.JobServiceHandle(
@@ -92,6 +107,22 @@ def __init__(self, options):
92107

93108
job_options = options.view_as(pipeline_options.JobServerOptions)
94109
self._job_port = job_options.job_port
110+
self._lock = threading.Lock()
111+
self._started = False
112+
self._endpoint = None
113+
114+
def start(self):
115+
with self._lock:
116+
if not self._started:
117+
self._endpoint = super().start()
118+
self._started = True
119+
return self._endpoint
120+
121+
def stop(self):
122+
with self._lock:
123+
if self._started:
124+
super().stop()
125+
self._started = False
95126

96127
@classmethod
97128
def maybe_unzip_and_make_executable(

Diff for: sdks/python/apache_beam/runners/portability/prism_runner_test.py

+41
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,47 @@ def test_with_remote_path(self, has_cache_bin, has_cache_zip, ignore_cache):
381381
mock_zipfile_init.assert_called_once()
382382

383383

384+
class PrismRunnerSingletonTest(unittest.TestCase):
385+
@parameterized.expand([True, False])
386+
def test_singleton(self, enable_singleton):
387+
if enable_singleton:
388+
options = DebugOptions(["--experiment=enable_prism_server_singleton"])
389+
else:
390+
options = DebugOptions()
391+
392+
with mock.patch(
393+
'apache_beam.runners.portability.job_server.subprocess_server.SubprocessServer.start' # pylint: disable=line-too-long
394+
) as mock_start:
395+
# Reset the class-level singleton for every fresh run
396+
prism_runner.PrismRunner._PrismRunner__singleton = None
397+
398+
try:
399+
with beam.Pipeline(options=options,
400+
runner=prism_runner.PrismRunner()) as p:
401+
_ = p | "Create Elements" >> beam.Create(
402+
range(5)) | "Squares" >> beam.Map(lambda x: x**2)
403+
except: # pylint: disable=bare-except
404+
pass
405+
406+
mock_start.assert_called_once()
407+
mock_start.reset_mock()
408+
409+
try:
410+
with beam.Pipeline(options=options,
411+
runner=prism_runner.PrismRunner()) as p:
412+
_ = p | "Create Elements" >> beam.Create(
413+
range(5)) | "Squares" >> beam.Map(lambda x: x**2)
414+
except: # pylint: disable=bare-except
415+
pass
416+
417+
if enable_singleton:
418+
# If singleton is enabled, we won't try to start a new server for the
419+
# second run.
420+
mock_start.assert_not_called()
421+
else:
422+
mock_start.assert_called_once()
423+
424+
384425
if __name__ == '__main__':
385426
# Run the tests.
386427
logging.getLogger().setLevel(logging.INFO)

0 commit comments

Comments
 (0)