Skip to content

Commit 56164ae

Browse files
nbjohnson0 and copybara-github
authored and committed
Add kd.to_proto_any function.
Fix sparsity handling in `kd.proto.get_proto_full_name`. PiperOrigin-RevId: 857219089 Change-Id: I0e968a5f715ba29bd599c7e4e8280c442cbc96b8
1 parent e69e716 commit 56164ae

File tree

9 files changed

+230
-4
lines changed

9 files changed

+230
-4
lines changed

docs/api_reference.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10804,6 +10804,25 @@ Args:
1080410804
Returns:
1080510805
A converted proto message or list of converted proto messages.</code></pre>
1080610806

10807+
### `kd.to_proto_any(x: DataSlice, *, descriptor_pool: DescriptorPool | None = None, deterministic: bool = False) -> Any | list[_NestedAnyMessageList] | None` {#kd.to_proto_any}
10808+
10809+
<pre class="no-copy"><code class="lang-text no-auto-prettify">Converts a DataSlice or DataItem to proto Any messages.
10810+
10811+
The schemas of all present values in `x` must have been derived from a proto
10812+
schema using `from_proto` or `schema_from_proto`, so that the original names
10813+
of the message types are embedded in the schema. Otherwise, this will fail.
10814+
10815+
Args:
10816+
x: DataSlice to convert.
10817+
descriptor_pool: Overrides the descriptor pool used to look up python proto
10818+
message classes based on proto message type full name. If None, the
10819+
default descriptor pool is used.
10820+
deterministic: Passed to Any.Pack.
10821+
10822+
Returns:
10823+
A proto Any message or nested list of proto Any messages with the same
10824+
shape as the input. Missing elements in the input are None in the output.</code></pre>
10825+
1080710826
### `kd.to_proto_bytes(x, proto_path, /)` {#kd.to_proto_bytes}
1080810827

1080910828
Alias for [kd.proto.to_proto_bytes](#kd.proto.to_proto_bytes) operator.

koladata/operators/proto.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,11 @@ absl::StatusOr<DataSlice> GetProtoFullName(const DataSlice& x) {
304304
// appropriate all-missing slice, we can reuse it as the return value.
305305
return x;
306306
} else {
307-
ASSIGN_OR_RETURN(schema, ExpandToShape(x.GetSchema(), x.GetShape(), 0));
307+
// schema = x.get_schema() & kd.has(x)
308+
schema = x.GetSchema();
309+
ASSIGN_OR_RETURN(auto has_x, ops::Has(x));
310+
ASSIGN_OR_RETURN(schema,
311+
ops::ApplyMask(std::move(schema), std::move(has_x)));
308312
}
309313

310314
ASSIGN_OR_RETURN(auto schema_metadata,

py/koladata/functions/functions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@
114114
get_attr_names = _attrs.get_attr_names
115115
# TODO: b/435124266 - Add to_proto explicitly to kd.py.
116116
to_proto = _proto_conversions.to_proto
117+
to_proto_any = _proto_conversions.to_proto_any
117118

118119
slice = _data_slice.DataSlice.from_vals # pylint: disable=redefined-builtin
119120
item = _data_item.DataItem.from_vals

py/koladata/functions/proto_conversions.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""Koda functions for converting to and from protocol buffers."""
1616

1717
from collections.abc import Iterator
18-
from typing import Type, TypeAlias, TypeVar, cast
18+
from typing import Any, Type, TypeAlias, TypeVar, cast
1919

2020
from google.protobuf import any as protobuf_any
2121
from google.protobuf import descriptor_pool as protobuf_descriptor_pool
@@ -47,6 +47,9 @@
4747
| tuple['_NestedMessageContainer', ...]
4848
| None
4949
)
50+
_NestedAnyMessageList: TypeAlias = (
51+
any_pb2.Any | list['_NestedAnyMessageList'] | None
52+
)
5053
_NestedAnyMessageContainer: TypeAlias = (
5154
any_pb2.Any
5255
| list['_NestedAnyMessageContainer']
@@ -73,7 +76,7 @@ def _flatten(x: _NestedMessageContainer) -> Iterator[message.Message | None]:
7376

7477
# Note: could use `tree.unflatten_as`, but it's not worth adding an additional
7578
# third-party dependency just for this.
76-
def _unflatten(shape: _NestedNoneList, it: Iterator[_T]) -> list[_T]:
79+
def _unflatten(shape: _NestedNoneList, it: Iterator[Any]) -> Any:
7780
"""Unflattens an iterator into a shape given by a nested list with None leaves."""
7881
return [_unflatten(x, it) for x in shape] if shape is not None else next(it)
7982

@@ -372,3 +375,58 @@ def to_proto(
372375
x_shape = (x & mask_constants.missing).to_py()
373376
results_flat = x.flatten()._to_proto(message_class) # pylint: disable=protected-access
374377
return _unflatten(x_shape, iter(results_flat))
378+
379+
380+
def to_proto_any(
    x: data_slice.DataSlice,
    *,
    descriptor_pool: protobuf_descriptor_pool.DescriptorPool | None = None,
    deterministic: bool = False,
) -> _NestedAnyMessageList:
  """Converts a DataSlice or DataItem to proto Any messages.

  The schemas of all present values in `x` must have been derived from a proto
  schema using `from_proto` or `schema_from_proto`, so that the original names
  of the message types are embedded in the schema. Otherwise, this will fail.

  Args:
    x: DataSlice to convert.
    descriptor_pool: Overrides the descriptor pool used to look up python proto
      message classes based on proto message type full name. If None, the
      default descriptor pool is used.
    deterministic: Passed to Any.Pack.

  Returns:
    A proto Any message or nested list of proto Any messages with the same
    shape as the input. Missing elements in the input are None in the output.

  Raises:
    ValueError: If a present element of `x` has no proto message full name
      (i.e. `kd.proto.get_proto_full_name` returns a missing value for it).
    KeyError: If a message full name cannot be found in the descriptor pool.
  """
  if descriptor_pool is None:
    descriptor_pool = protobuf_descriptor_pool.Default()

  # Masking with `missing` keeps only the structure of `x`; the resulting
  # nested list of None leaves is the template `_unflatten` fills in below.
  x_shape = (x & mask_constants.missing).to_py()
  x_flat = x.flatten()
  full_names = expr_eval.eval(
      kde_operators.kde.proto.get_proto_full_name(x_flat)
  ).to_py()

  # Resolve each distinct message type only once: a slice typically contains
  # many elements but only a handful of different message types.
  message_types = {}
  results = []
  for i, full_name in enumerate(full_names):
    x_item = x_flat.S[i]
    if full_name is None:
      # No full name is fine for a missing element, but a present element
      # without one did not originate from a proto message.
      if not x_item.is_empty():
        raise ValueError(
            'to_proto_any expects entities converted from proto messages, got'
            f' {x_item}'
        )
      results.append(None)
      continue

    message_type = message_types.get(full_name)
    if message_type is None:
      message_type = protobuf_message_factory.GetMessageClass(
          descriptor_pool.FindMessageTypeByName(full_name)
      )
      message_types[full_name] = message_type

    result = any_pb2.Any()
    result.Pack(to_proto(x_item, message_type), deterministic=deterministic)
    results.append(result)

  return _unflatten(x_shape, iter(results))

py/koladata/functions/tests/BUILD

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -986,6 +986,20 @@ py_test(
986986
],
987987
)
988988

989+
py_test(
990+
name = "to_proto_any_test",
991+
srcs = ["to_proto_any_test.py"],
992+
deps = [
993+
":test_py_pb2",
994+
"//py:python_path",
995+
"//py/koladata/functions:proto_conversions",
996+
"//py/koladata/testing",
997+
"//py/koladata/types:data_slice",
998+
"@com_google_absl_py//absl/testing:absltest",
999+
"@com_google_protobuf//:protobuf_python",
1000+
],
1001+
)
1002+
9891003
proto_library(
9901004
name = "test_proto",
9911005
srcs = ["test.proto"],
py/koladata/functions/tests/to_proto_any_test.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from absl.testing import absltest
16+
from google.protobuf import descriptor_pool
17+
from koladata.functions import proto_conversions
18+
from koladata.functions.tests import test_pb2
19+
from koladata.testing import testing
20+
from koladata.types import data_slice
21+
from google.protobuf import any_pb2
22+
23+
ds = data_slice.DataSlice.from_vals
24+
25+
26+
class ToProtoAnyTest(absltest.TestCase):
  """Tests for `proto_conversions.to_proto_any`."""

  @staticmethod
  def _pack(msg):
    """Returns `msg` packed into a freshly created `any_pb2.Any`."""
    packed = any_pb2.Any()
    packed.Pack(msg)
    return packed

  def _assert_unpacks_to(self, packed, expected):
    """Asserts that `packed` unpacks into a message equal to `expected`."""
    unpacked = type(expected)()
    packed.Unpack(unpacked)
    self.assertEqual(unpacked, expected)

  def test_zero_items(self):
    self.assertEqual(proto_conversions.to_proto_any(ds([])), [])

  def test_single_item(self):
    msg = test_pb2.MessageA(some_text='thing 1')
    res = proto_conversions.to_proto_any(proto_conversions.from_proto(msg))
    self.assertIsInstance(res, any_pb2.Any)
    self._assert_unpacks_to(res, msg)

  def test_single_none(self):
    self.assertIsNone(proto_conversions.to_proto_any(ds(None)))

  def test_list_with_none(self):
    msg = test_pb2.MessageA(some_text='thing 1')
    res = proto_conversions.to_proto_any(
        proto_conversions.from_proto([msg, None])
    )
    self.assertIsInstance(res, list)
    self.assertLen(res, 2)
    first, second = res[0], res[1]
    self.assertIsInstance(first, any_pb2.Any)
    self.assertIsNone(second)
    self._assert_unpacks_to(first, msg)

  def test_multiple_messages_same_type(self):
    msg_1 = test_pb2.MessageA(some_text='thing 1')
    msg_2 = test_pb2.MessageA(some_text='thing 2')
    res = proto_conversions.to_proto_any(
        proto_conversions.from_proto([msg_1, msg_2])
    )
    self.assertIsInstance(res, list)
    self.assertLen(res, 2)
    self._assert_unpacks_to(res[0], msg_1)
    self._assert_unpacks_to(res[1], msg_2)

  def test_multiple_messages_different_types(self):
    msg_a = test_pb2.MessageA(some_text='thing 1')
    msg_b = test_pb2.MessageB(text='thing 2')
    # Messages of different types can only share a DataSlice when they are
    # loaded through from_proto_any.
    x = proto_conversions.from_proto_any(
        [self._pack(msg_a), self._pack(msg_b)]
    )

    res = proto_conversions.to_proto_any(x)
    self.assertIsInstance(res, list)
    self.assertLen(res, 2)
    self._assert_unpacks_to(res[0], msg_a)
    self._assert_unpacks_to(res[1], msg_b)

  def test_nested_list_input(self):
    any_1 = self._pack(test_pb2.MessageA(some_text='1'))
    any_2 = self._pack(test_pb2.MessageA(some_text='2'))
    any_3 = self._pack(test_pb2.MessageB(text='3'))

    x = proto_conversions.from_proto_any(
        [[any_1, None, any_2], [], [any_3]]
    )
    res = proto_conversions.to_proto_any(x)
    # Round-tripping through Any messages must reproduce the original slice.
    testing.assert_equivalent(proto_conversions.from_proto_any(res), x)

  def test_not_from_proto(self):
    not_a_proto = ds([1, 2, 3]).implode()
    with self.assertRaisesRegex(
        ValueError,
        'to_proto_any expects entities converted from proto messages',
    ):
      proto_conversions.to_proto_any(not_a_proto)

  def test_empty_descriptor_pool(self):
    x = proto_conversions.from_proto(test_pb2.MessageA(some_text='thing 1'))
    # A brand-new pool knows no message types, so the lookup must fail.
    empty_pool = descriptor_pool.DescriptorPool()
    with self.assertRaisesRegex(KeyError, 'MessageA'):
      proto_conversions.to_proto_any(x, descriptor_pool=empty_pool)
122+
123+
124+
if __name__ == '__main__':
125+
absltest.main()

py/koladata/kd_dynamic.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -831,6 +831,7 @@ to_json = _operators_json.to_json
831831
to_none = _operators_schema.to_none
832832
to_object = _operators_schema.to_object
833833
to_proto = _functions_proto_conversions.to_proto
834+
to_proto_any = _functions_proto_conversions.to_proto_any
834835
to_proto_bytes = _operators_proto.to_proto_bytes
835836
to_proto_json = _operators_proto.to_proto_json
836837
to_schema = _operators_schema.to_schema

py/koladata/operators/kde_operators_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
'from_pytree',
3939
'py_reference',
4040
'to_proto',
41+
'to_proto_any',
4142
'to_py',
4243
'to_pylist',
4344
'to_pytree',

py/koladata/operators/tests/proto_get_proto_full_name_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,13 @@ def test_from_proto(self):
5454

5555
def test_from_proto_slice(self):
5656
m = test_pb2.MessageA(some_text='hello')
57-
kd_m = proto_conversions.from_proto([m, m])
57+
kd_m = proto_conversions.from_proto([m, None, m])
5858
result = expr_eval.eval(kde.proto.get_proto_full_name(kd_m))
5959
testing.assert_equal(
6060
result.no_bag(),
6161
ds([
6262
'koladata.functions.testing.MessageA',
63+
None,
6364
'koladata.functions.testing.MessageA',
6465
]),
6566
)
@@ -69,13 +70,15 @@ def test_from_proto_object_slice(self):
6970
m2 = test_pb2.MessageB(text='bar')
7071
kd_m = ds([
7172
fns.obj(proto_conversions.from_proto(m1)),
73+
None,
7274
fns.obj(proto_conversions.from_proto(m2)),
7375
])
7476
result = expr_eval.eval(kde.proto.get_proto_full_name(kd_m))
7577
testing.assert_equal(
7678
result.no_bag(),
7779
ds([
7880
'koladata.functions.testing.MessageA',
81+
None,
7982
'koladata.functions.testing.MessageB',
8083
]),
8184
)

0 commit comments

Comments
 (0)