Skip to content

Commit bdb3939

Browse files
claudevdmClaude
and
Claude
authored
Handle pickling top-level enums. (#34568)
* Handle pickling top-level enums. * Add tests. * Lint fix. * Add license header. --------- Co-authored-by: Claude <[email protected]>
1 parent 8c592aa commit bdb3939

File tree

4 files changed

+72
-31
lines changed

4 files changed

+72
-31
lines changed

Diff for: sdks/java/extensions/protobuf/src/test/proto/proto2_coder_test_messages.proto

+14
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,17 @@ message MessageWithMap {
5151
message ReferencesMessageWithMap {
5252
repeated MessageWithMap field1 = 1;
5353
}
54+
55+
enum TopLevelEnum {
56+
UNSPECIFIED = 0;
57+
ONE = 1;
58+
TWO = 2;
59+
}
60+
61+
message MessageD {
62+
enum NestedEnum {
63+
UNSPECIFIED = 0;
64+
ONE = 1;
65+
TWO = 2;
66+
}
67+
}

Diff for: sdks/python/apache_beam/coders/proto2_coder_test_messages_pb2.py

+41-27
Original file line numberDiff line numberDiff line change
@@ -17,42 +17,56 @@
1717

1818
# -*- coding: utf-8 -*-
1919
# Generated by the protocol buffer compiler. DO NOT EDIT!
20+
# NO CHECKED-IN PROTOBUF GENCODE
2021
# source: apache_beam/coders/proto2_coder_test_messages.proto
21-
22+
# Protobuf Python Version: 5.28.0
2223
"""Generated protocol buffer code."""
23-
from google.protobuf.internal import builder as _builder
2424
from google.protobuf import descriptor as _descriptor
2525
from google.protobuf import descriptor_pool as _descriptor_pool
26+
from google.protobuf import runtime_version as _runtime_version
2627
from google.protobuf import symbol_database as _symbol_database
28+
from google.protobuf.internal import builder as _builder
29+
_runtime_version.ValidateProtobufRuntimeVersion(
30+
_runtime_version.Domain.PUBLIC,
31+
5,
32+
28,
33+
0,
34+
'',
35+
'apache_beam/coders/proto2_coder_test_messages.proto'
36+
)
2737
# @@protoc_insertion_point(imports)
2838

2939
_sym_db = _symbol_database.Default()
3040

31-
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
32-
b'\n3apache_beam/coders/proto2_coder_test_messages.proto\x12\x1aproto2_coder_test_messages\"P\n\x08MessageA\x12\x0e\n\x06\x66ield1\x18\x01 \x01(\t\x12\x34\n\x06\x66ield2\x18\x02 \x03(\x0b\x32$.proto2_coder_test_messages.MessageB\"\x1a\n\x08MessageB\x12\x0e\n\x06\x66ield1\x18\x01 \x01(\x08\"\x10\n\x08MessageC*\x04\x08\x64\x10j\"\xad\x01\n\x0eMessageWithMap\x12\x46\n\x06\x66ield1\x18\x01 \x03(\x0b\x32\x36.proto2_coder_test_messages.MessageWithMap.Field1Entry\x1aS\n\x0b\x46ield1Entry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x33\n\x05value\x18\x02 \x01(\x0b\x32$.proto2_coder_test_messages.MessageA:\x02\x38\x01\"V\n\x18ReferencesMessageWithMap\x12:\n\x06\x66ield1\x18\x01 \x03(\x0b\x32*.proto2_coder_test_messages.MessageWithMap:Z\n\x06\x66ield1\x12$.proto2_coder_test_messages.MessageC\x18\x65 \x01(\x0b\x32$.proto2_coder_test_messages.MessageA:Z\n\x06\x66ield2\x12$.proto2_coder_test_messages.MessageC\x18\x66 \x01(\x0b\x32$.proto2_coder_test_messages.MessageBB)\n\'org.apache.beam.sdk.extensions.protobuf'
33-
)
3441

35-
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
36-
_builder.BuildTopDescriptorsAndMessages(
37-
DESCRIPTOR, 'apache_beam.coders.proto2_coder_test_messages_pb2', globals())
38-
if _descriptor._USE_C_DESCRIPTORS == False:
39-
MessageC.RegisterExtension(field1)
40-
MessageC.RegisterExtension(field2)
4142

42-
DESCRIPTOR._options = None
43-
DESCRIPTOR._serialized_options = b'\n\'org.apache.beam.sdk.extensions.protobuf'
44-
_MESSAGEWITHMAP_FIELD1ENTRY._options = None
45-
_MESSAGEWITHMAP_FIELD1ENTRY._serialized_options = b'8\001'
46-
_MESSAGEA._serialized_start = 83
47-
_MESSAGEA._serialized_end = 163
48-
_MESSAGEB._serialized_start = 165
49-
_MESSAGEB._serialized_end = 191
50-
_MESSAGEC._serialized_start = 193
51-
_MESSAGEC._serialized_end = 209
52-
_MESSAGEWITHMAP._serialized_start = 212
53-
_MESSAGEWITHMAP._serialized_end = 385
54-
_MESSAGEWITHMAP_FIELD1ENTRY._serialized_start = 302
55-
_MESSAGEWITHMAP_FIELD1ENTRY._serialized_end = 385
56-
_REFERENCESMESSAGEWITHMAP._serialized_start = 387
57-
_REFERENCESMESSAGEWITHMAP._serialized_end = 473
43+
44+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n3apache_beam/coders/proto2_coder_test_messages.proto\x12\x1aproto2_coder_test_messages\"P\n\x08MessageA\x12\x0e\n\x06\x66ield1\x18\x01 \x01(\t\x12\x34\n\x06\x66ield2\x18\x02 \x03(\x0b\x32$.proto2_coder_test_messages.MessageB\"\x1a\n\x08MessageB\x12\x0e\n\x06\x66ield1\x18\x01 \x01(\x08\"\x10\n\x08MessageC*\x04\x08\x64\x10j\"\xad\x01\n\x0eMessageWithMap\x12\x46\n\x06\x66ield1\x18\x01 \x03(\x0b\x32\x36.proto2_coder_test_messages.MessageWithMap.Field1Entry\x1aS\n\x0b\x46ield1Entry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x33\n\x05value\x18\x02 \x01(\x0b\x32$.proto2_coder_test_messages.MessageA:\x02\x38\x01\"V\n\x18ReferencesMessageWithMap\x12:\n\x06\x66ield1\x18\x01 \x03(\x0b\x32*.proto2_coder_test_messages.MessageWithMap\";\n\x08MessageD\"/\n\nNestedEnum\x12\x0f\n\x0bUNSPECIFIED\x10\x00\x12\x07\n\x03ONE\x10\x01\x12\x07\n\x03TWO\x10\x02*1\n\x0cTopLevelEnum\x12\x0f\n\x0bUNSPECIFIED\x10\x00\x12\x07\n\x03ONE\x10\x01\x12\x07\n\x03TWO\x10\x02:Z\n\x06\x66ield1\x12$.proto2_coder_test_messages.MessageC\x18\x65 \x01(\x0b\x32$.proto2_coder_test_messages.MessageA:Z\n\x06\x66ield2\x12$.proto2_coder_test_messages.MessageC\x18\x66 \x01(\x0b\x32$.proto2_coder_test_messages.MessageBB)\n\'org.apache.beam.sdk.extensions.protobuf')
45+
46+
_globals = globals()
47+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
48+
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'apache_beam.coders.proto2_coder_test_messages_pb2', _globals)
49+
if not _descriptor._USE_C_DESCRIPTORS:
50+
_globals['DESCRIPTOR']._loaded_options = None
51+
_globals['DESCRIPTOR']._serialized_options = b'\n\'org.apache.beam.sdk.extensions.protobuf'
52+
_globals['_MESSAGEWITHMAP_FIELD1ENTRY']._loaded_options = None
53+
_globals['_MESSAGEWITHMAP_FIELD1ENTRY']._serialized_options = b'8\001'
54+
_globals['_TOPLEVELENUM']._serialized_start=536
55+
_globals['_TOPLEVELENUM']._serialized_end=585
56+
_globals['_MESSAGEA']._serialized_start=83
57+
_globals['_MESSAGEA']._serialized_end=163
58+
_globals['_MESSAGEB']._serialized_start=165
59+
_globals['_MESSAGEB']._serialized_end=191
60+
_globals['_MESSAGEC']._serialized_start=193
61+
_globals['_MESSAGEC']._serialized_end=209
62+
_globals['_MESSAGEWITHMAP']._serialized_start=212
63+
_globals['_MESSAGEWITHMAP']._serialized_end=385
64+
_globals['_MESSAGEWITHMAP_FIELD1ENTRY']._serialized_start=302
65+
_globals['_MESSAGEWITHMAP_FIELD1ENTRY']._serialized_end=385
66+
_globals['_REFERENCESMESSAGEWITHMAP']._serialized_start=387
67+
_globals['_REFERENCESMESSAGEWITHMAP']._serialized_end=473
68+
_globals['_MESSAGED']._serialized_start=475
69+
_globals['_MESSAGED']._serialized_end=534
70+
_globals['_MESSAGED_NESTEDENUM']._serialized_start=487
71+
_globals['_MESSAGED_NESTEDENUM']._serialized_end=534
5872
# @@protoc_insertion_point(module_scope)

Diff for: sdks/python/apache_beam/internal/cloudpickle_pickler.py

+5
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@ def _reconstruct_enum_descriptor(full_name):
8686
if not hasattr(module, 'DESCRIPTOR'):
8787
continue
8888

89+
if hasattr(module.DESCRIPTOR, 'enum_types_by_name'):
90+
for (_, enum_desc) in module.DESCRIPTOR.enum_types_by_name.items():
91+
if enum_desc.full_name == full_name:
92+
return enum_desc
93+
8994
for _, attr_value in vars(module).items():
9095
if not hasattr(attr_value, 'DESCRIPTOR'):
9196
continue

Diff for: sdks/python/apache_beam/internal/cloudpickle_pickler_test.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -23,21 +23,29 @@
2323
import types
2424
import unittest
2525

26+
from apache_beam.coders import proto2_coder_test_messages_pb2
2627
from apache_beam.internal import module_test
2728
from apache_beam.internal.cloudpickle_pickler import dumps
2829
from apache_beam.internal.cloudpickle_pickler import loads
29-
from apache_beam.portability.api import beam_runner_api_pb2
3030

3131

3232
class PicklerTest(unittest.TestCase):
3333

3434
NO_MAPPINGPROXYTYPE = not hasattr(types, "MappingProxyType")
3535

36-
def test_pickle_enum_descriptor(self):
37-
TimeDomain = beam_runner_api_pb2.TimeDomain.Enum
36+
def test_pickle_nested_enum_descriptor(self):
37+
NestedEnum = proto2_coder_test_messages_pb2.MessageD.NestedEnum
3838

3939
def fn():
40-
return TimeDomain.EVENT_TIME
40+
return NestedEnum.TWO
41+
42+
self.assertEqual(fn(), loads(dumps(fn))())
43+
44+
def test_pickle_top_level_enum_descriptor(self):
45+
TopLevelEnum = proto2_coder_test_messages_pb2.TopLevelEnum
46+
47+
def fn():
48+
return TopLevelEnum.ONE
4149

4250
self.assertEqual(fn(), loads(dumps(fn))())
4351

0 commit comments

Comments
 (0)