Skip to content

Commit a833b2f

Browse files
committed
Add reducers for threading.Lock and EnumDescriptor.
1 parent 1cf9755 commit a833b2f

File tree

2 files changed

+78
-0
lines changed

2 files changed

+78
-0
lines changed

sdks/python/apache_beam/internal/cloudpickle_pickler.py

+63
Original file line numberDiff line numberDiff line change
@@ -30,20 +30,73 @@
3030
import base64
3131
import bz2
3232
import io
33+
import sys
3334
import threading
3435
import zlib
3536

37+
from google.protobuf.internal import api_implementation
38+
3639
from apache_beam.internal.cloudpickle import cloudpickle
3740

3841
try:
3942
from absl import flags
4043
except (ImportError, ModuleNotFoundError):
4144
pass
4245

46+
47+
def _get_proto_enum_descriptor_class():
  """Return the EnumDescriptor class of the active protobuf backend.

  The backend name reported by ``api_implementation.Type()`` selects which
  module the class is imported from ('upb', 'cpp' or 'python').

  Returns:
    The backend's ``EnumDescriptor`` class, or ``None`` when the backend name
    is not one of the three handled here or the import raises ImportError.
  """
  backend = api_implementation.Type()

  # One guard covers all branches: only the branch matching the backend runs,
  # and an ImportError from it means "no descriptor class available".
  try:
    if backend == 'upb':
      from google._upb._message import EnumDescriptor as descriptor_cls
    elif backend == 'cpp':
      from google.protobuf.pyext._message import (
          EnumDescriptor as descriptor_cls)
    elif backend == 'python':
      from google.protobuf.internal.python_message import (
          EnumDescriptor as descriptor_cls)
    else:
      return None
  except ImportError:
    return None

  return descriptor_cls
70+
71+
72+
# Descriptor class for the active protobuf backend, resolved once at import
# time; None when no backend matched or its module could not be imported.
EnumDescriptor = _get_proto_enum_descriptor_class()
73+
4374
# Pickling, especially unpickling, causes broken module imports on Python 3
4475
# if executed concurrently, see: BEAM-8651, http://bugs.python.org/issue38884.
4576
_pickle_lock = threading.RLock()
4677
RLOCK_TYPE = type(_pickle_lock)
78+
# Concrete type of the objects threading.Lock() creates; dumps() uses it as
# the dispatch_table key for _lock_reducer.
LOCK_TYPE = type(threading.Lock())
79+
80+
81+
def _reconstruct_enum_descriptor(full_name):
  """Locate an already-imported proto enum descriptor by its full name.

  This is the reconstructor that ``_pickle_enum_descriptor`` hands to pickle:
  at load time the descriptor is not rebuilt, it is looked up among the
  modules currently present in ``sys.modules``.

  For every module exposing a ``DESCRIPTOR`` attribute, two places are
  searched:
    * the module's own ``DESCRIPTOR.enum_types_by_name`` (enums declared at
      file level), and
    * ``DESCRIPTOR.enum_types_by_name`` of every module attribute that itself
      has a ``DESCRIPTOR`` (enums nested inside a message).

  Args:
    full_name: The ``full_name`` of the enum descriptor to find.

  Returns:
    The first enum descriptor whose ``full_name`` equals ``full_name``.

  Raises:
    ImportError: If no loaded module provides a matching enum descriptor.
  """
  def _find_in(descriptor):
    # Return the matching enum held by `descriptor`, or None.
    enums_by_name = getattr(descriptor, 'enum_types_by_name', None)
    if enums_by_name is None:
      return None
    for enum_desc in enums_by_name.values():
      if enum_desc.full_name == full_name:
        return enum_desc
    return None

  # Iterate over snapshots: sys.modules (and a module's namespace) may grow
  # while we scan, e.g. through imports in other threads or as a side effect
  # of attribute access, which would otherwise raise
  # "RuntimeError: dictionary changed size during iteration".
  for module in list(sys.modules.values()):
    if not hasattr(module, 'DESCRIPTOR'):
      continue

    # Enums declared at the top level of the proto file.
    found = _find_in(module.DESCRIPTOR)
    if found is not None:
      return found

    # Enums nested inside message classes defined in the module.
    for attr_value in list(vars(module).values()):
      if not hasattr(attr_value, 'DESCRIPTOR'):
        continue
      found = _find_in(attr_value.DESCRIPTOR)
      if found is not None:
        return found

  raise ImportError(f'Could not find enum descriptor: {full_name}')
95+
96+
97+
def _pickle_enum_descriptor(obj):
  """Reduce an enum descriptor to a lookup by its full name.

  Args:
    obj: The enum descriptor being pickled; only ``obj.full_name`` is read.

  Returns:
    A ``(callable, args)`` pair: ``_reconstruct_enum_descriptor`` and a
    one-element tuple holding the descriptor's full name.
  """
  reconstructor_args = (obj.full_name, )
  return _reconstruct_enum_descriptor, reconstructor_args
47100

48101

49102
def dumps(o, enable_trace=True, use_zlib=False) -> bytes:
@@ -59,6 +112,12 @@ def dumps(o, enable_trace=True, use_zlib=False) -> bytes:
59112
pickler.dispatch_table[RLOCK_TYPE] = _pickle_rlock
60113
except NameError:
61114
pass
115+
try:
116+
pickler.dispatch_table[LOCK_TYPE] = _lock_reducer
117+
except NameError:
118+
pass
119+
if EnumDescriptor is not None:
120+
pickler.dispatch_table[EnumDescriptor] = _pickle_enum_descriptor
62121
pickler.dump(o)
63122
s = file.getvalue()
64123

@@ -106,6 +165,10 @@ def _pickle_rlock(obj):
106165
return RLOCK_TYPE, tuple([])
107166

108167

168+
def _lock_reducer(obj):
  """Reduce a lock to a recipe that creates a brand-new lock.

  Args:
    obj: The lock being pickled. It is not inspected, so its locked/unlocked
      state is not carried over.

  Returns:
    A ``(callable, args)`` pair: ``threading.Lock`` and an empty tuple.
  """
  no_args = ()
  return threading.Lock, no_args
170+
171+
109172
def dump_session(file_path):
110173
# It is possible to dump session with cloudpickle. However, since references
111174
# are saved it should not be necessary. See https://s.apache.org/beam-picklers

sdks/python/apache_beam/internal/cloudpickle_pickler_test.py

+15
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,21 @@
2626
from apache_beam.internal import module_test
2727
from apache_beam.internal.cloudpickle_pickler import dumps
2828
from apache_beam.internal.cloudpickle_pickler import loads
29+
from apache_beam.portability.api import beam_runner_api_pb2
2930

3031

3132
class PicklerTest(unittest.TestCase):
3233

3334
NO_MAPPINGPROXYTYPE = not hasattr(types, "MappingProxyType")
3435

36+
def test_pickle_enum_descriptor(self):
  """A closure capturing a proto enum gives the same value after a round trip."""
  time_domain_enum = beam_runner_api_pb2.TimeDomain.Enum

  def event_time():
    return time_domain_enum.EVENT_TIME

  expected = event_time()
  restored = loads(dumps(event_time))
  self.assertEqual(expected, restored())
43+
3544
def test_basics(self):
3645
self.assertEqual([1, 'a', ('z', )], loads(dumps([1, 'a', ('z', )])))
3746
fun = lambda x: 'xyz-%s' % x
@@ -97,6 +106,12 @@ def test_pickle_rlock(self):
97106

98107
self.assertIsInstance(loads(dumps(rlock_instance)), rlock_type)
99108

109+
def test_pickle_lock(self):
  """A threading.Lock round-trips through dumps/loads as a lock object."""
  original = threading.Lock()
  round_tripped = loads(dumps(original))

  self.assertIsInstance(round_tripped, type(original))
114+
100115
@unittest.skipIf(NO_MAPPINGPROXYTYPE, 'test if MappingProxyType introduced')
101116
def test_dump_and_load_mapping_proxy(self):
102117
self.assertEqual(

0 commit comments

Comments
 (0)