Skip to content

Commit 56fb00f

Browse files
atorero authored and copybara-github committed
Add support for protos in from_py.
PiperOrigin-RevId: 858223810 Change-Id: Ide1ea43f958851e748431dde471801d611dd99d6
1 parent 593271f commit 56fb00f

File tree

12 files changed

+563
-34
lines changed

12 files changed

+563
-34
lines changed

py/koladata/base/BUILD

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,31 @@ cc_library(
227227
],
228228
)
229229

230+
# Utilities for converting Python proto messages to Koda objects
# (py_proto_utils.h / py_proto_utils.cc).
cc_library(
231+
name = "py_proto_utils",
232+
srcs = ["py_proto_utils.cc"],
233+
hdrs = ["py_proto_utils.h"],
234+
deps = [
235+
"//koladata:data_bag",
236+
"//koladata:data_slice",
237+
"//koladata/internal:data_item",
238+
"//koladata/internal:data_slice",
239+
"//koladata/internal:dtype",
240+
"//koladata/operators:lib",
241+
"//koladata/proto:from_proto",
242+
"//py/koladata/types:pybind11_protobuf_wrapper",
243+
"@com_google_absl//absl/base:nullability",
244+
"@com_google_absl//absl/status:statusor",
245+
"@com_google_absl//absl/strings:string_view",
246+
"@com_google_absl//absl/types:span",
247+
"@com_google_arolla//arolla/util",
248+
"@com_google_arolla//arolla/util:status_backport",
249+
"@com_google_arolla//py/arolla/py_utils",
250+
"@com_google_protobuf//:protobuf",
251+
"@rules_python//python/cc:current_py_cc_headers", # buildcleaner: keep
252+
],
253+
)
254+
230255
koladata_py_extension(
231256
name = "py_functors_base_py_ext",
232257
srcs = ["py_functors_base_module.cc"],

py/koladata/base/py_conversions/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ cc_library(
7676
"//koladata/internal:schema_attrs",
7777
"//koladata/internal/op_utils:trampoline_executor",
7878
"//py/koladata/base:boxing",
79+
"//py/koladata/base:py_proto_utils",
80+
"//py/koladata/types:pybind11_protobuf_wrapper",
7981
"@com_google_absl//absl/base:core_headers",
8082
"@com_google_absl//absl/base:no_destructor",
8183
"@com_google_absl//absl/base:nullability",

py/koladata/base/py_conversions/from_py.cc

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@
5353
#include "koladata/uuid_utils.h"
5454
#include "py/koladata/base/boxing.h"
5555
#include "py/koladata/base/py_conversions/dataclasses_util.h"
56+
#include "py/koladata/base/py_proto_utils.h"
57+
#include "py/koladata/types/pybind11_protobuf_wrapper.h"
5658
#include "arolla/util/status_macros_backport.h"
5759

5860
namespace koladata::python {
@@ -228,6 +230,21 @@ class FromPyConverter {
228230
return false;
229231
}
230232

233+
// Returns true if the given Python objects should be treated as a proto.
234+
// Checks the first non-None object.
235+
absl::StatusOr<bool> IsProto(const std::vector<PyObject*>& py_objects) {
236+
if (py_objects.empty()) {
237+
return false;
238+
}
239+
for (PyObject* py_obj : py_objects) {
240+
if (Py_IsNone(py_obj)) {
241+
continue;
242+
}
243+
return IsPyProtoMessage(py_obj);
244+
}
245+
return false;
246+
}
247+
231248
// Verifies that dict_as_obj is not set for dict schema.
232249
absl::Status VerifyDictAsObj(const std::optional<DataSlice>& schema) {
233250
if (dict_as_obj_ && schema && schema->IsDictSchema()) {
@@ -512,6 +529,12 @@ class FromPyConverter {
512529
cur_depth, executor, result);
513530
}
514531

532+
ASSIGN_OR_RETURN(bool is_proto, IsProto(py_objects));
533+
if (is_proto) {
534+
return ConvertProto(py_objects, std::move(cur_shape), schema,
535+
std::move(itemid), result);
536+
}
537+
515538
if (IsEntity(py_objects, schema)) {
516539
DCHECK(schema.has_value());
517540
return ConvertEntities(py_objects, std::move(cur_shape), *schema,
@@ -1002,6 +1025,23 @@ class FromPyConverter {
10021025
return absl::OkStatus();
10031026
}
10041027

1028+
// Converts a batch of Python proto messages (None elements are handled by
// FromProtoObjects) and stores the resulting DataSlice, reshaped to
// `cur_shape`, in `result`.
//
// `schema` and `itemid` are forwarded to FromProtoObjects; `itemid`, when
// given, is flattened first. Returns the first error reported by
// FromProtoObjects or by the final Reshape.
absl::Status ConvertProto(const std::vector<PyObject*>& py_objects,
                          DataSlice::JaggedShape cur_shape,
                          std::optional<DataSlice> schema,
                          std::optional<DataSlice> itemid,
                          std::optional<DataSlice>& result) {
  // The conversion below works on the flat list of objects, so item ids are
  // passed in flattened form as well.
  std::optional<DataSlice> flat_itemid;
  if (itemid) {
    flat_itemid = itemid->Flatten();
  }

  ASSIGN_OR_RETURN(
      DataSlice flat_protos,
      FromProtoObjects(GetBag(), py_objects, /*extensions=*/{},
                       /*itemid=*/flat_itemid, /*schema=*/schema));

  // Restore the shape the caller expects for this level.
  ASSIGN_OR_RETURN(result, flat_protos.Reshape(std::move(cur_shape)));
  return absl::OkStatus();
}
1044+
10051045
// Converts Python objects (dataclasses) to Objects/Entities.
10061046
// If the schema is kObject, creates an Object DataSlice. Otherwise (which
10071047
// can only happen if the schema is nullopt), creates an Entity DataSlice.

py/koladata/base/py_proto_utils.cc

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
//
15+
#include "py/koladata/base/py_proto_utils.h"
16+
17+
#include <any>
18+
#include <optional>
19+
#include <utility>
20+
#include <vector>
21+
22+
#include "absl/base/nullability.h"
23+
#include "absl/status/statusor.h"
24+
#include "absl/strings/string_view.h"
25+
#include "absl/types/span.h"
26+
#include "arolla/util/unit.h"
27+
#include "koladata/data_bag.h"
28+
#include "koladata/data_slice.h"
29+
#include "koladata/internal/data_item.h"
30+
#include "koladata/internal/dtype.h"
31+
#include "koladata/internal/slice_builder.h"
32+
#include "koladata/operators/slices.h"
33+
#include "koladata/proto/from_proto.h"
34+
#include "google/protobuf/message.h"
35+
#include "py/arolla/py_utils/py_utils.h"
36+
#include "py/koladata/types/pybind11_protobuf_wrapper.h"
37+
#include "arolla/util/status_macros_backport.h"
38+
39+
namespace koladata::python {
40+
// Converts `py_objects` (proto messages and/or None) into a DataSlice.
//
// Builds a presence mask with one slot per input object, unwraps every
// non-None object into a proto message pointer, converts the dense list of
// messages with `FromProto`, and finally expands the dense result back over
// the mask with `ops::InverseSelect`.
//
// Returns an error if unwrapping any non-None object fails, or if any of the
// downstream calls (`DataSlice::Create`, `FromProto`, `InverseSelect`) fails.
absl::StatusOr<DataSlice> FromProtoObjects(
    const absl_nonnull DataBagPtr& db, const std::vector<PyObject*>& py_objects,
    absl::Span<const absl::string_view> extensions,
    const std::optional<DataSlice>& itemid,
    const std::optional<DataSlice>& schema) {
  arolla::python::DCheckPyGIL();
  const Py_ssize_t num_objects = py_objects.size();

  internal::SliceBuilder presence_builder(num_objects);
  auto typed_presence_builder = presence_builder.typed<arolla::Unit>();

  // `owners` holds whatever UnwrapPyProtoMessage returns alongside each
  // pointer, for as long as `messages` is in use.
  std::vector<std::any> owners;
  std::vector<const ::google::protobuf::Message* absl_nonnull> messages;
  owners.reserve(num_objects);
  messages.reserve(num_objects);

  Py_ssize_t next_position = 0;
  for (PyObject* py_object : py_objects) {  // Borrowed references.
    const Py_ssize_t position = next_position++;
    if (py_object == Py_None) {
      continue;  // Stays unset in the mask, i.e. missing in the result.
    }
    typed_presence_builder.InsertIfNotSet(position, arolla::kUnit);

    ASSIGN_OR_RETURN((auto [message, owner]),
                     python::UnwrapPyProtoMessage(py_object));
    owners.push_back(std::move(owner));
    messages.push_back(message);
  }

  ASSIGN_OR_RETURN(
      auto presence_mask,
      DataSlice::Create(std::move(presence_builder).Build(),
                        DataSlice::JaggedShape::FlatFromSize(num_objects),
                        internal::DataItem(schema::kMask)));

  // NOTE(review): `itemid` is forwarded unchanged although only the non-None
  // messages are passed on; confirm the sizes agree when `py_objects` contains
  // None elements.
  ASSIGN_OR_RETURN(DataSlice dense_slice,
                   FromProto(db, messages, extensions, itemid, schema));
  ASSIGN_OR_RETURN(DataSlice expanded_slice,
                   ops::InverseSelect(dense_slice, presence_mask));
  return expanded_slice;
}
77+
} // namespace koladata::python

py/koladata/base/py_proto_utils.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
//
15+
// Utilities for converting Python proto messages to Koda objects.
16+
17+
#ifndef KOLADATA_BASE_PY_PROTO_UTILS_H_
18+
#define KOLADATA_BASE_PY_PROTO_UTILS_H_
19+
20+
#include <optional>
21+
#include <vector>
22+
23+
#include "Python.h"
24+
#include "absl/base/nullability.h"
25+
#include "absl/status/statusor.h"
26+
#include "absl/strings/string_view.h"
27+
#include "absl/types/span.h"
28+
#include "koladata/data_bag.h"
29+
#include "koladata/data_slice.h"
30+
31+
namespace koladata::python {
32+
// Treats Python objects (potentially sparse) as proto Messages and converts
33+
// them to Koda objects; None elements become missing items in the result.
34+
// Calls `FromProto` on the non-None objects only and then puts the converted
35+
// items back at the positions of the non-None inputs. If an element is neither
36+
// None nor a proto Message that can be unwrapped, returns an error status.
//
// `db`, `extensions`, `itemid` and `schema` are forwarded to `FromProto`.
// Must be called with the Python GIL held.
37+
absl::StatusOr<DataSlice> FromProtoObjects(
38+
const absl_nonnull DataBagPtr& db, const std::vector<PyObject*>& py_objects,
39+
absl::Span<const absl::string_view> extensions,
40+
const std::optional<DataSlice>& itemid,
41+
const std::optional<DataSlice>& schema);
42+
} // namespace koladata::python
43+
44+
#endif // KOLADATA_BASE_PY_PROTO_UTILS_H_

py/koladata/functions/tests/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,10 +565,12 @@ py_test(
565565
name = "from_py_test",
566566
srcs = ["from_py_test.py"],
567567
deps = [
568+
":test_py_pb2",
568569
"//py:python_path",
569570
"//py/koladata/functions",
570571
"//py/koladata/functions:attrs",
571572
"//py/koladata/functions:object_factories",
573+
"//py/koladata/functions:proto_conversions",
572574
"//py/koladata/functions:py_conversions",
573575
"//py/koladata/operators:kde_operators",
574576
"//py/koladata/testing",

0 commit comments

Comments
 (0)