Skip to content

Commit ae7bf20

Browse files
authored
Add option to pickler dumps() for best-effort determinism (#34698)
* Add option to pickler dumps() for best-effort determinism The motivation for this change is Flume caches pickled code and non-determinism breaks the caching. While making pickling fully deterministic is infeasible, increasing the determinism is still useful due to the increase in cache hits. Sets are a common source of non-determinism. This change sorts set elements to provide deterministic serialization. Because not all types provide a comparison operator, the sorting routine implements generic element comparison logic. See: #34410 * Fix linter errors * Restore previous pickler settings and improve determinism test cases * Restore pickler settings in finally block * Plumb feature flag from runner to to_runner_api_pickled The FlumeRunner will enable this feature flag via its canary mechanism. The option is localized to only serialization of transforms because that is what the FlumeRunner wants to cache. It will not enable best-effort deterministic serialization for other uses of pickling.
1 parent d3b526a commit ae7bf20

File tree

10 files changed

+542
-6
lines changed

10 files changed

+542
-6
lines changed

sdks/python/apache_beam/internal/cloudpickle_pickler.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,16 @@ def _pickle_enum_descriptor(obj):
107107
return _reconstruct_enum_descriptor, (full_name, )
108108

109109

110-
def dumps(o, enable_trace=True, use_zlib=False) -> bytes:
110+
def dumps(
111+
o,
112+
enable_trace=True,
113+
use_zlib=False,
114+
enable_best_effort_determinism=False) -> bytes:
111115
"""For internal use only; no backwards-compatibility guarantees."""
116+
if enable_best_effort_determinism:
117+
# TODO: Add support once https://github.com/cloudpipe/cloudpickle/pull/563
118+
# is merged in.
119+
raise NotImplementedError('This option has only been implemeneted for dill')
112120
with _pickle_lock:
113121
with io.BytesIO() as file:
114122
pickler = cloudpickle.CloudPickler(file)

sdks/python/apache_beam/internal/cloudpickle_pickler_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,10 @@ def test_dataclass(self):
212212
self.assertEqual(DataClass(datum='abc'), loads(dumps(DataClass(datum='abc'))))
213213
''')
214214

215+
def test_best_effort_determinism_not_implemented(self):
216+
with self.assertRaises(NotImplementedError):
217+
dumps(123, enable_best_effort_determinism=True)
218+
215219

216220
if __name__ == '__main__':
217221
unittest.main()

sdks/python/apache_beam/internal/dill_pickler.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@
4444

4545
import dill
4646

47+
from apache_beam.internal.set_pickler import save_frozenset
48+
from apache_beam.internal.set_pickler import save_set
49+
4750
settings = {'dill_byref': None}
4851

4952
patch_save_code = sys.version_info >= (3, 10) and dill.__version__ == "0.3.1.1"
@@ -376,9 +379,18 @@ def new_log_info(msg, *args, **kwargs):
376379
logging.getLogger('dill').setLevel(logging.WARN)
377380

378381

379-
def dumps(o, enable_trace=True, use_zlib=False) -> bytes:
382+
def dumps(
383+
o,
384+
enable_trace=True,
385+
use_zlib=False,
386+
enable_best_effort_determinism=False) -> bytes:
380387
"""For internal use only; no backwards-compatibility guarantees."""
381388
with _pickle_lock:
389+
if enable_best_effort_determinism:
390+
old_save_set = dill.dill.Pickler.dispatch[set]
391+
old_save_frozenset = dill.dill.Pickler.dispatch[frozenset]
392+
dill.dill.pickle(set, save_set)
393+
dill.dill.pickle(frozenset, save_frozenset)
382394
try:
383395
s = dill.dumps(o, byref=settings['dill_byref'])
384396
except Exception: # pylint: disable=broad-except
@@ -389,6 +401,9 @@ def dumps(o, enable_trace=True, use_zlib=False) -> bytes:
389401
raise
390402
finally:
391403
dill.dill._trace(False) # pylint: disable=protected-access
404+
if enable_best_effort_determinism:
405+
dill.dill.pickle(set, old_save_set)
406+
dill.dill.pickle(frozenset, old_save_frozenset)
392407

393408
# Compress as compactly as possible (compresslevel=9) to decrease peak memory
394409
# usage (of multiple in-memory copies) and to avoid hitting protocol buffer

sdks/python/apache_beam/internal/pickler.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,17 @@
3838
desired_pickle_lib = dill_pickler
3939

4040

41-
def dumps(o, enable_trace=True, use_zlib=False) -> bytes:
41+
def dumps(
42+
o,
43+
enable_trace=True,
44+
use_zlib=False,
45+
enable_best_effort_determinism=False) -> bytes:
4246

4347
return desired_pickle_lib.dumps(
44-
o, enable_trace=enable_trace, use_zlib=use_zlib)
48+
o,
49+
enable_trace=enable_trace,
50+
use_zlib=use_zlib,
51+
enable_best_effort_determinism=enable_best_effort_determinism)
4552

4653

4754
def loads(encoded, enable_trace=True, use_zlib=False):

sdks/python/apache_beam/internal/pickler_test.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
# pytype: skip-file
2121

22+
import random
2223
import sys
2324
import threading
2425
import types
@@ -115,6 +116,52 @@ def test_dataclass(self):
115116
self.assertEqual(DataClass(datum='abc'), loads(dumps(DataClass(datum='abc'))))
116117
''')
117118

119+
def maybe_get_sets_with_different_iteration_orders(self):
120+
# Use a mix of types in an attempt to create sets with the same elements
121+
# whose iteration order is different.
122+
elements = [
123+
100,
124+
'hello',
125+
3.14159,
126+
True,
127+
None,
128+
-50,
129+
'world',
130+
False, (1, 2), (4, 3), ('hello', 'world')
131+
]
132+
set1 = set(elements)
133+
# Try random addition orders until finding an order that works.
134+
for _ in range(100):
135+
set2 = set()
136+
random.shuffle(elements)
137+
for e in elements:
138+
set2.add(e)
139+
if list(set1) != list(set2):
140+
break
141+
return set1, set2
142+
143+
def test_best_effort_determinism(self):
144+
set1, set2 = self.maybe_get_sets_with_different_iteration_orders()
145+
self.assertEqual(
146+
dumps(set1, enable_best_effort_determinism=True),
147+
dumps(set2, enable_best_effort_determinism=True))
148+
# The test relies on the sets having different iteration orders for the
149+
# elements. Iteration order is implementation dependent and undefined,
150+
# meaning the test won't always be able to setup these conditions.
151+
if list(set1) == list(set2):
152+
self.skipTest('Set iteration orders matched. Test results inconclusive.')
153+
154+
def test_disable_best_effort_determinism(self):
155+
set1, set2 = self.maybe_get_sets_with_different_iteration_orders()
156+
# The test relies on the sets having different iteration orders for the
157+
# elements. Iteration order is implementation dependent and undefined,
158+
# meaning the test won't always be able to setup these conditions.
159+
if list(set1) == list(set2):
160+
self.skipTest('Set iteration orders matched. Unable to complete test.')
161+
self.assertNotEqual(
162+
dumps(set1, enable_best_effort_determinism=False),
163+
dumps(set2, enable_best_effort_determinism=False))
164+
118165

119166
if __name__ == '__main__':
120167
unittest.main()
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
"""Custom pickling logic for sets to make the serialization semi-deterministic.
19+
20+
To make set serialization semi-deterministic, we must pick an order for the set
21+
elements. Sets may contain elements of types not defining a comparison "<"
22+
operator. To provide an order, we define our own custom comparison function
23+
which supports elements of near-arbitrary types and use that to sort the
24+
contents of each set during serialization. Attempts at determinism are made on a
25+
best-effort basis to improve hit rates for cached workflows and the ordering
26+
does not define a total order for all values.
27+
"""
28+
29+
import enum
30+
import functools
31+
32+
33+
def compare(lhs, rhs):
34+
"""Returns -1, 0, or 1 depending on whether lhs <, =, or > rhs."""
35+
if lhs < rhs:
36+
return -1
37+
elif lhs > rhs:
38+
return 1
39+
else:
40+
return 0
41+
42+
43+
def generic_object_comparison(lhs, rhs, lhs_path, rhs_path, max_depth):
44+
"""Identifies which object goes first in an (almost) total order of objects.
45+
46+
Args:
47+
lhs: An arbitrary Python object or built-in type.
48+
rhs: An arbitrary Python object or built-in type.
49+
lhs_path: Traversal path from the root lhs object up to, but not including,
50+
lhs. The original contents of lhs_path are restored before the function
51+
returns.
52+
rhs_path: Same as lhs_path except for the rhs.
53+
max_depth: Maximum recursion depth.
54+
55+
Returns:
56+
-1, 0, or 1 depending on whether lhs or rhs goes first in the total order.
57+
0 if max_depth is exhausted.
58+
0 if lhs is in lhs_path or rhs is in rhs_path (there is a cycle).
59+
"""
60+
if id(lhs) == id(rhs):
61+
# Fast path
62+
return 0
63+
if type(lhs) != type(rhs):
64+
return compare(str(type(lhs)), str(type(rhs)))
65+
if type(lhs) in [int, float, bool, str, bool, bytes, bytearray]:
66+
return compare(lhs, rhs)
67+
if isinstance(lhs, enum.Enum):
68+
# Enums can have values with arbitrary types. The names are strings.
69+
return compare(lhs.name, rhs.name)
70+
71+
# To avoid exceeding the recursion depth limit, set a limit on recursion.
72+
max_depth -= 1
73+
if max_depth < 0:
74+
return 0
75+
76+
# Check for cycles in the traversal path to avoid getting stuck in a loop.
77+
if id(lhs) in lhs_path or id(rhs) in rhs_path:
78+
return 0
79+
lhs_path.append(id(lhs))
80+
rhs_path.append(id(rhs))
81+
# The comparison logic is split across two functions to simplifying updating
82+
# and restoring the traversal paths.
83+
result = _generic_object_comparison_recursive_path(
84+
lhs, rhs, lhs_path, rhs_path, max_depth)
85+
lhs_path.pop()
86+
rhs_path.pop()
87+
return result
88+
89+
90+
def _generic_object_comparison_recursive_path(
91+
lhs, rhs, lhs_path, rhs_path, max_depth):
92+
if type(lhs) == tuple or type(lhs) == list:
93+
result = compare(len(lhs), len(rhs))
94+
if result != 0:
95+
return result
96+
for i in range(len(lhs)):
97+
result = generic_object_comparison(
98+
lhs[i], rhs[i], lhs_path, rhs_path, max_depth)
99+
if result != 0:
100+
return result
101+
return 0
102+
if type(lhs) == frozenset or type(lhs) == set:
103+
return generic_object_comparison(
104+
tuple(sort_if_possible(lhs, lhs_path, rhs_path, max_depth)),
105+
tuple(sort_if_possible(rhs, lhs_path, rhs_path, max_depth)),
106+
lhs_path,
107+
rhs_path,
108+
max_depth)
109+
if type(lhs) == dict:
110+
lhs_keys = list(lhs.keys())
111+
rhs_keys = list(rhs.keys())
112+
result = compare(len(lhs_keys), len(rhs_keys))
113+
if result != 0:
114+
return result
115+
lhs_keys = sort_if_possible(lhs_keys, lhs_path, rhs_path, max_depth)
116+
rhs_keys = sort_if_possible(rhs_keys, lhs_path, rhs_path, max_depth)
117+
for lhs_key, rhs_key in zip(lhs_keys, rhs_keys):
118+
result = generic_object_comparison(
119+
lhs_key, rhs_key, lhs_path, rhs_path, max_depth)
120+
if result != 0:
121+
return result
122+
result = generic_object_comparison(
123+
lhs[lhs_key], rhs[rhs_key], lhs_path, rhs_path, max_depth)
124+
if result != 0:
125+
return result
126+
127+
lhs_fields = dir(lhs)
128+
rhs_fields = dir(rhs)
129+
result = compare(len(lhs_fields), len(rhs_fields))
130+
if result != 0:
131+
return result
132+
for i in range(len(lhs_fields)):
133+
result = compare(lhs_fields[i], rhs_fields[i])
134+
if result != 0:
135+
return result
136+
result = generic_object_comparison(
137+
getattr(lhs, lhs_fields[i], None),
138+
getattr(rhs, rhs_fields[i], None),
139+
lhs_path,
140+
rhs_path,
141+
max_depth)
142+
if result != 0:
143+
return result
144+
return 0
145+
146+
147+
def sort_if_possible(obj, lhs_path=None, rhs_path=None, max_depth=4):
148+
def cmp(lhs, rhs):
149+
if lhs_path is None:
150+
# Start the traversal at the root call to cmp.
151+
return generic_object_comparison(lhs, rhs, [], [], max_depth)
152+
else:
153+
# Continue the existing traversal path for recursive calls to cmp.
154+
return generic_object_comparison(lhs, rhs, lhs_path, rhs_path, max_depth)
155+
156+
return sorted(obj, key=functools.cmp_to_key(cmp))
157+
158+
159+
def save_set(pickler, obj):
160+
pickler.save_set(sort_if_possible(obj))
161+
162+
163+
def save_frozenset(pickler, obj):
164+
pickler.save_frozenset(sort_if_possible(obj))

0 commit comments

Comments
 (0)