Skip to content

Commit 2044cdb

Browse files
nbjohnson0copybara-github
authored andcommitted
Allow kd.from_proto to use extensions from both an explicit schema and the extensions argument.
Previously, the `extensions` argument was ignored if there was an explicit schema. This behavior also wasn't clearly documented at the python layer. Also correct some outdated statements in the kd.from_proto docstring. PiperOrigin-RevId: 857273382 Change-Id: I97c35fa4d08b241721a9b17e81767f47e597be76
1 parent 730f2ea commit 2044cdb

File tree

6 files changed

+163
-39
lines changed

6 files changed

+163
-39
lines changed

docs/api_reference.md

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9820,14 +9820,14 @@ respectively. Enums are converted to INT32. The attribute names on the Koda
98209820
objects match the field names in the proto definition. See below for methods
98219821
to convert proto extensions to attributes alongside regular fields.
98229822

9823-
Messages, primitive fields, repeated fields, and maps are converted to
9824-
equivalent Koda structures. Enums are converted to ints.
9823+
If schema is not specified or schema is kd.OBJECT, only present fields in
9824+
`messages` are loaded and included in the converted schema. To get a schema
9825+
that contains all fields independent of the data, use `kd.schema_from_proto`.
98259826

9826-
Only present values in `messages` are added. Default and missing values are
9827-
not used.
9827+
Proto extensions are ignored by default unless `extensions` is specified, or
9828+
if an explicit entity schema with parenthesized attr names is specified. If
9829+
both are specified, we use the union of the two extension sets.
98289830

9829-
Proto extensions are ignored by default unless `extensions` is specified (or
9830-
if an explicit entity schema with parenthesized attrs is used).
98319831
The format of each extension specified in `extensions` is a dot-separated
98329832
sequence of field names and/or extension names, where extension names are
98339833
fully-qualified extension paths surrounded by parentheses. This sequence of
@@ -9839,6 +9839,9 @@ default behavior of traversing all fields. For example:
98399839
"path.to.map_field.values.(package_name.some_extension)"
98409840
"path.(package_name.some_extension).(package_name2.nested_extension)"
98419841

9842+
If an explicit entity schema attr name starts with "(" and ends with ")" it is
9843+
also interpreted as an extension name.
9844+
98429845
Extensions are looked up using the C++ generated descriptor pool, using
98439846
`DescriptorPool::FindExtensionByName`, which requires that all extensions are
98449847
compiled in as C++ protos. The Koda attribute names for the extension fields
@@ -9852,7 +9855,8 @@ be accessed using `.get_attr(name)'. For example,
98529855
ds.optional_field.get_attr('(package_name.DefExtension.def_extension)')
98539856

98549857
If `messages` is a single proto Message, the result is a DataItem. If it is a
9855-
list of proto Messages, the result is an 1D DataSlice.
9858+
nested list of proto Messages, the result is a DataSlice with the same number
9859+
of dimensions as the nesting level.
98569860

98579861
Args:
98589862
messages: Message or nested list/tuple of Message of the same type. Any of

koladata/proto/from_proto.cc

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,6 +1007,8 @@ absl::Status FromProtoMessage(
10071007
struct FieldVars {
10081008
const FieldDescriptor* absl_nonnull field_descriptor;
10091009
absl::string_view attr_name;
1010+
// Whether this is an extension field that is not on the requested schema.
1011+
bool is_non_schema_extension = false;
10101012
std::optional<DataSlice> value;
10111013
};
10121014

@@ -1033,7 +1035,11 @@ absl::Status FromProtoMessage(
10331035
// For explicit entity schemas, use the schema attr names as the list of
10341036
// fields and extensions to convert.
10351037
ASSIGN_OR_RETURN(vars->schema_attr_names, schema->GetAttrNames());
1036-
vars->fields.reserve(vars->schema_attr_names.size());
1038+
const int64_t max_num_fields =
1039+
vars->schema_attr_names.size() +
1040+
((extension_map != nullptr) ? extension_map->extension_fields.size()
1041+
: 0);
1042+
vars->fields.reserve(max_num_fields);
10371043
for (const auto& attr_name : vars->schema_attr_names) {
10381044
if (attr_name.starts_with('(') && attr_name.ends_with(')')) {
10391045
// Interpret attrs with parentheses as fully-qualified extension paths.
@@ -1053,11 +1059,11 @@ absl::Status FromProtoMessage(
10531059
field->full_name(), message_descriptor.full_name(),
10541060
field->containing_type()->full_name()));
10551061
}
1056-
vars->fields.emplace_back(field, attr_name);
1062+
vars->fields.emplace_back(field, attr_name, false);
10571063
} else {
10581064
const auto* field = message_descriptor.FindFieldByName(attr_name);
10591065
if (field != nullptr) {
1060-
vars->fields.emplace_back(field, attr_name);
1066+
vars->fields.emplace_back(field, attr_name, false);
10611067
}
10621068
}
10631069
}
@@ -1071,19 +1077,23 @@ absl::Status FromProtoMessage(
10711077
for (int i_field = 0; i_field < message_descriptor.field_count();
10721078
++i_field) {
10731079
const auto* field = message_descriptor.field(i_field);
1074-
vars->fields.emplace_back(field, field->name());
1080+
vars->fields.emplace_back(field, field->name(), false);
10751081
}
1076-
if (extension_map != nullptr) {
1077-
for (const auto& [attr_name, field] : extension_map->extension_fields) {
1078-
if (field->containing_type() != &message_descriptor) {
1079-
return absl::InvalidArgumentError(absl::StrFormat(
1080-
"extension \"%s\" exists, but isn't an extension "
1081-
"on target message type \"%s\", expected \"%s\"",
1082-
field->full_name(), message_descriptor.full_name(),
1083-
field->containing_type()->full_name()));
1084-
}
1085-
vars->fields.emplace_back(field, attr_name);
1082+
}
1083+
if (extension_map != nullptr) {
1084+
for (const auto& [attr_name, field] : extension_map->extension_fields) {
1085+
if (vars->schema_attr_names.contains(attr_name)) {
1086+
// Already handled by the schema.
1087+
continue;
1088+
}
1089+
if (field->containing_type() != &message_descriptor) {
1090+
return absl::InvalidArgumentError(absl::StrFormat(
1091+
"extension \"%s\" exists, but isn't an extension "
1092+
"on target message type \"%s\", expected \"%s\"",
1093+
field->full_name(), message_descriptor.full_name(),
1094+
field->containing_type()->full_name()));
10861095
}
1096+
vars->fields.emplace_back(field, attr_name, true);
10871097
}
10881098
}
10891099

@@ -1098,7 +1108,8 @@ absl::Status FromProtoMessage(
10981108
for (auto& field_vars : vars->fields) {
10991109
RETURN_IF_ERROR(FromProtoField(
11001110
db, field_vars.attr_name, field_vars.attr_name,
1101-
*field_vars.field_descriptor, messages, itemid, schema,
1111+
*field_vars.field_descriptor, messages, itemid,
1112+
field_vars.is_non_schema_extension ? std::nullopt : schema,
11021113
allocated_schema_metadata, extension_map,
11031114
/*ignore_field_presence=*/false, executor, field_vars.value));
11041115
}
@@ -1107,10 +1118,12 @@ absl::Status FromProtoMessage(
11071118
&result]() -> absl::Status {
11081119
std::vector<absl::string_view> value_attr_names;
11091120
std::vector<DataSlice> values;
1121+
bool has_non_schema_extensions = false;
11101122
for (auto& field_vars : vars->fields) {
11111123
if (field_vars.value.has_value()) {
11121124
values.push_back(*std::move(field_vars.value));
11131125
value_attr_names.push_back(field_vars.attr_name);
1126+
has_non_schema_extensions |= field_vars.is_non_schema_extension;
11141127
}
11151128
}
11161129

@@ -1126,13 +1139,13 @@ absl::Status FromProtoMessage(
11261139
/*itemid=*/vars->itemid));
11271140
} else { // schema != OBJECT
11281141
ASSIGN_OR_RETURN(
1129-
result,
1130-
EntityCreator::Shaped(db, std::move(result_shape),
1131-
/*attr_names=*/std::move(value_attr_names),
1132-
/*values=*/std::move(values),
1133-
/*schema=*/std::move(vars->requested_schema),
1134-
/*overwrite_schema=*/false,
1135-
/*itemid=*/vars->itemid));
1142+
result, EntityCreator::Shaped(
1143+
db, std::move(result_shape),
1144+
/*attr_names=*/std::move(value_attr_names),
1145+
/*values=*/std::move(values),
1146+
/*schema=*/std::move(vars->requested_schema),
1147+
/*overwrite_schema=*/has_non_schema_extensions,
1148+
/*itemid=*/vars->itemid));
11361149
}
11371150
} else { // schema == nullopt
11381151
ASSIGN_OR_RETURN(result, EntityCreator::Shaped(

koladata/proto/from_proto.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,9 @@ namespace koladata {
5959
// `schema` were nullopt, but converted to OBJECT recursively.
6060
//
6161
// If `schema` is an entity schema, the set of attr names on that schema is the
62-
// set of fields and extensions that are converted, overriding the default
63-
// behavior and ignoring `extensions`. Attr names that start with '(' and end
64-
// with ')' are interpreted as fully-qualified extension names and cause the
62+
// set of fields and extensions that are converted, in addition to the
63+
// extensions in `extensions`. Attr names that start with '(' and end with ')'
64+
// are interpreted as fully-qualified extension names and cause the
6565
// corresponding extension to be converted if present on the message. If a
6666
// sub-schema of this schema is OBJECT, the corresponding sub-messages are
6767
// converted using the OBJECT rules above.

koladata/proto/from_proto_test.cc

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1148,6 +1148,56 @@ TEST(FromProtoTest, ExtensionViaSchema) {
11481148
IsOkAndHolds(IsEquivalentTo(test::DataSlice<bool>({true}, db))));
11491149
}
11501150

1151+
TEST(FromProtoTest, ExtensionViaSchemaExtensionsAndExtensionsListUnion) {
1152+
testing::ExampleMessage2 message;
1153+
message.SetExtension(koladata::testing::m2_bool_extension_field, false);
1154+
message.MutableExtension(koladata::testing::m2_message2_extension_field)
1155+
->SetExtension(koladata::testing::m2_bool_extension_field, true);
1156+
1157+
auto db = DataBag::EmptyMutable();
1158+
auto schema = test::EntitySchema(
1159+
{
1160+
"(koladata.testing.m2_message2_extension_field)",
1161+
},
1162+
{
1163+
test::EntitySchema(
1164+
{
1165+
"(koladata.testing.m2_bool_extension_field)",
1166+
},
1167+
{
1168+
test::Schema(schema::kBool),
1169+
},
1170+
db),
1171+
},
1172+
db);
1173+
1174+
ASSERT_OK_AND_ASSIGN(
1175+
auto result, FromProto(db, {&message},
1176+
{
1177+
"(koladata.testing.m2_bool_extension_field)",
1178+
},
1179+
std::nullopt, schema));
1180+
1181+
EXPECT_THAT(result.GetAttrNames(),
1182+
IsOkAndHolds(UnorderedElementsAreArray({
1183+
"(koladata.testing.m2_bool_extension_field)",
1184+
"(koladata.testing.m2_message2_extension_field)",
1185+
})));
1186+
EXPECT_THAT(result.GetAttr("(koladata.testing.m2_bool_extension_field)"),
1187+
IsOkAndHolds(IsEquivalentTo(test::DataSlice<bool>({false}, db))));
1188+
1189+
ASSERT_OK_AND_ASSIGN(
1190+
auto message_ext_field,
1191+
result.GetAttr("(koladata.testing.m2_message2_extension_field)"));
1192+
EXPECT_THAT(message_ext_field.GetAttrNames(),
1193+
IsOkAndHolds(UnorderedElementsAreArray({
1194+
"(koladata.testing.m2_bool_extension_field)",
1195+
})));
1196+
EXPECT_THAT(
1197+
message_ext_field.GetAttr("(koladata.testing.m2_bool_extension_field)"),
1198+
IsOkAndHolds(IsEquivalentTo(test::DataSlice<bool>({true}, db))));
1199+
}
1200+
11511201
TEST(FromProtoTest, InvalidExtensionPath) {
11521202
testing::ExampleMessage message;
11531203

py/koladata/functions/proto_conversions.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,14 @@ def from_proto(
9797
objects match the field names in the proto definition. See below for methods
9898
to convert proto extensions to attributes alongside regular fields.
9999
100-
Messages, primitive fields, repeated fields, and maps are converted to
101-
equivalent Koda structures. Enums are converted to ints.
100+
If schema is not specified or schema is kd.OBJECT, only present fields in
101+
`messages` are loaded and included in the converted schema. To get a schema
102+
that contains all fields independent of the data, use `kd.schema_from_proto`.
102103
103-
Only present values in `messages` are added. Default and missing values are
104-
not used.
104+
Proto extensions are ignored by default unless `extensions` is specified, or
105+
if an explicit entity schema with parenthesized attr names is specified. If
106+
both are specified, we use the union of the two extension sets.
105107
106-
Proto extensions are ignored by default unless `extensions` is specified (or
107-
if an explicit entity schema with parenthesized attrs is used).
108108
The format of each extension specified in `extensions` is a dot-separated
109109
sequence of field names and/or extension names, where extension names are
110110
fully-qualified extension paths surrounded by parentheses. This sequence of
@@ -116,6 +116,9 @@ def from_proto(
116116
"path.to.map_field.values.(package_name.some_extension)"
117117
"path.(package_name.some_extension).(package_name2.nested_extension)"
118118
119+
If an explicit entity schema attr name starts with "(" and ends with ")" it is
120+
also interpreted as an extension name.
121+
119122
Extensions are looked up using the C++ generated descriptor pool, using
120123
`DescriptorPool::FindExtensionByName`, which requires that all extensions are
121124
compiled in as C++ protos. The Koda attribute names for the extension fields
@@ -129,7 +132,8 @@ def from_proto(
129132
ds.optional_field.get_attr('(package_name.DefExtension.def_extension)')
130133
131134
If `messages` is a single proto Message, the result is a DataItem. If it is a
132-
list of proto Messages, the result is an 1D DataSlice.
135+
nested list of proto Messages, the result is a DataSlice with the same number
136+
of dimensions as the nesting level.
133137
134138
Args:
135139
messages: Message or nested list/tuple of Message of the same type. Any of

py/koladata/functions/tests/from_proto_test.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,59 @@ def test_extensions(self):
398398
'koladata.functions.testing.MessageAExtension',
399399
)
400400

401+
def test_extensions_from_arg_and_schema(self):
402+
m = test_pb2.MessageA(
403+
some_text='thing 1',
404+
)
405+
406+
m.message_set_extensions.Extensions[
407+
test_pb2.MessageAExtension.message_set_extension
408+
].extra = 1
409+
m.Extensions[test_pb2.MessageAExtension.message_a_extension].extra = 2
410+
411+
s = proto_conversions.schema_from_proto(
412+
test_pb2.MessageA,
413+
extensions=[
414+
'(koladata.functions.testing.MessageAExtension.message_a_extension)'
415+
],
416+
)
417+
x = proto_conversions.from_proto(
418+
m,
419+
schema=s,
420+
extensions=[
421+
'message_set_extensions.(koladata.functions.testing.MessageAExtension.message_set_extension)'
422+
],
423+
)
424+
425+
self.assertCountEqual(
426+
x.get_attr_names(intersection=True),
427+
[
428+
'some_text',
429+
'some_float',
430+
'message_b_list',
431+
'message_set_extensions',
432+
'(koladata.functions.testing.MessageAExtension.message_a_extension)',
433+
],
434+
)
435+
self.assertEqual(
436+
x.get_attr(
437+
'(koladata.functions.testing.MessageAExtension.message_a_extension)'
438+
).extra,
439+
2,
440+
)
441+
self.assertCountEqual(
442+
x.message_set_extensions.get_attr_names(intersection=True),
443+
[
444+
'(koladata.functions.testing.MessageAExtension.message_set_extension)'
445+
],
446+
)
447+
self.assertEqual(
448+
x.message_set_extensions.get_attr(
449+
'(koladata.functions.testing.MessageAExtension.message_set_extension)'
450+
).extra,
451+
1,
452+
)
453+
401454
def test_extension_on_wrong_message_error(self):
402455
with self.assertRaisesRegex(
403456
ValueError,

0 commit comments

Comments
 (0)