|
| 1 | +// Copyright 2025 Google LLC |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | +// |
| 15 | +#include "py/koladata/base/py_proto_utils.h" |
| 16 | + |
| 17 | +#include <any> |
| 18 | +#include <optional> |
| 19 | +#include <utility> |
| 20 | +#include <vector> |
| 21 | + |
| 22 | +#include "absl/base/nullability.h" |
| 23 | +#include "absl/status/statusor.h" |
| 24 | +#include "absl/strings/string_view.h" |
| 25 | +#include "absl/types/span.h" |
| 26 | +#include "arolla/util/unit.h" |
| 27 | +#include "koladata/data_bag.h" |
| 28 | +#include "koladata/data_slice.h" |
| 29 | +#include "koladata/internal/data_item.h" |
| 30 | +#include "koladata/internal/dtype.h" |
| 31 | +#include "koladata/internal/slice_builder.h" |
| 32 | +#include "koladata/operators/slices.h" |
| 33 | +#include "koladata/proto/from_proto.h" |
| 34 | +#include "google/protobuf/message.h" |
| 35 | +#include "py/arolla/py_utils/py_utils.h" |
| 36 | +#include "py/koladata/types/pybind11_protobuf_wrapper.h" |
| 37 | +#include "arolla/util/status_macros_backport.h" |
| 38 | + |
| 39 | +namespace koladata::python { |
| 40 | +absl::StatusOr<DataSlice> FromProtoObjects( |
| 41 | + const absl_nonnull DataBagPtr& db, const std::vector<PyObject*>& py_objects, |
| 42 | + absl::Span<const absl::string_view> extensions, |
| 43 | + const std::optional<DataSlice>& itemid, |
| 44 | + const std::optional<DataSlice>& schema) { |
| 45 | + arolla::python::DCheckPyGIL(); |
| 46 | + const Py_ssize_t messages_list_len = py_objects.size(); |
| 47 | + |
| 48 | + internal::SliceBuilder message_mask_builder(messages_list_len); |
| 49 | + auto typed_message_mask_builder = message_mask_builder.typed<arolla::Unit>(); |
| 50 | + std::vector<std::any> message_owners; |
| 51 | + message_owners.reserve(messages_list_len); |
| 52 | + std::vector<const ::google::protobuf::Message* absl_nonnull> message_ptrs; |
| 53 | + message_ptrs.reserve(messages_list_len); |
| 54 | + for (Py_ssize_t i = 0; i < messages_list_len; ++i) { |
| 55 | + PyObject* py_message = py_objects[i]; // Borrowed. |
| 56 | + if (py_message != Py_None) { |
| 57 | + typed_message_mask_builder.InsertIfNotSet(i, arolla::kUnit); |
| 58 | + |
| 59 | + ASSIGN_OR_RETURN((auto [message_ptr, message_owner]), |
| 60 | + python::UnwrapPyProtoMessage(py_message)); |
| 61 | + message_owners.push_back(std::move(message_owner)); |
| 62 | + message_ptrs.push_back(message_ptr); |
| 63 | + } |
| 64 | + } |
| 65 | + ASSIGN_OR_RETURN( |
| 66 | + auto message_mask, |
| 67 | + DataSlice::Create(std::move(message_mask_builder).Build(), |
| 68 | + DataSlice::JaggedShape::FlatFromSize(messages_list_len), |
| 69 | + internal::DataItem(schema::kMask))); |
| 70 | + |
| 71 | + ASSIGN_OR_RETURN(DataSlice dense_result, |
| 72 | + FromProto(db, message_ptrs, extensions, itemid, schema)); |
| 73 | + ASSIGN_OR_RETURN(DataSlice result, |
| 74 | + ops::InverseSelect(dense_result, message_mask)); |
| 75 | + return result; |
| 76 | +} |
| 77 | +} // namespace koladata::python |
0 commit comments