@@ -98,12 +98,14 @@ class ToPyVisitor : internal::AbstractVisitor {
9898 public:
9999 ToPyVisitor (bool obj_as_dict, bool include_missing_attrs,
100100 const absl::flat_hash_set<ObjectId>& objects_not_to_convert,
101- DataBagPtr db, internal::DataBagImpl::FallbackSpan fallback_span)
101+ DataBagPtr db, internal::DataBagImpl::FallbackSpan fallback_span,
102+ PyObject* output_class)
102103 : obj_as_dict_(obj_as_dict),
103104 include_missing_attrs_ (include_missing_attrs),
104105 objects_not_to_convert_(objects_not_to_convert),
105106 db_(std::move(db)),
106- fallback_span_(fallback_span) {}
107+ fallback_span_(fallback_span),
108+ output_class_(PyObjectPtr::NewRef(output_class)) {}
107109
108110 ItemToPyConverter GetItemToPyConverter (const DataItem& schema) {
109111 return [&](const DataItem& item) -> absl::StatusOr<PyObjectPtr> {
@@ -408,26 +410,34 @@ class ToPyVisitor : internal::AbstractVisitor {
408410 DataClassesUtil dataclasses_util_;
409411 DataBagPtr db_;
410412 internal::DataBagImpl::FallbackSpan fallback_span_;
413+ PyObjectPtr output_class_;
411414};
412415
413416PyObject* absl_nullable ToPyImplInternal (
414417 const DataSlice& ds, DataBagPtr bag, bool obj_as_dict,
415- bool include_missing_attrs, const absl::flat_hash_set<ObjectId>& leaf_ids) {
418+ bool include_missing_attrs, const absl::flat_hash_set<ObjectId>& leaf_ids,
419+ PyObject* output_class) {
416420 if (ds.IsEmpty () || bag == nullptr ||
417421 GetNarrowedSchema (ds).is_primitive_schema ()) {
418422 ASSIGN_OR_RETURN (PyObjectPtr res, PyObjectFromDataSlice (ds),
419423 arolla::python::SetPyErrFromStatus (_));
420424 return res.release ();
421425 }
426+ if (obj_as_dict && output_class != Py_None) {
427+ arolla::python::SetPyErrFromStatus (absl::InvalidArgumentError (
428+ " obj_as_dict cannot be used with output_class" ));
429+ return nullptr ;
430+ }
422431 const DataSlice& schema = ds.GetSchema ();
423432 DCHECK (bag != nullptr );
424433
425434 FlattenFallbackFinder fb_finder (*bag);
426435 const internal::DataBagImpl::FallbackSpan fallback_span =
427436 fb_finder.GetFlattenFallbacks ();
428437 // We use original DataBag in ToPyVisitor.
429- std::shared_ptr<ToPyVisitor> visitor = std::make_shared<ToPyVisitor>(
430- obj_as_dict, include_missing_attrs, leaf_ids, bag, fallback_span);
438+ std::shared_ptr<ToPyVisitor> visitor =
439+ std::make_shared<ToPyVisitor>(obj_as_dict, include_missing_attrs,
440+ leaf_ids, bag, fallback_span, output_class);
431441 // We use extracted DataBag for traversal.
432442 internal::Traverser<ToPyVisitor> traverse_op (bag->GetImpl (), fallback_span,
433443 visitor);
@@ -451,7 +461,8 @@ PyObject* absl_nullable ToPyImplInternal(
451461
452462PyObject* absl_nullable ToPyImpl (const DataSlice& ds, DataBagPtr bag,
453463 int max_depth, bool obj_as_dict,
454- bool include_missing_attrs) {
464+ bool include_missing_attrs,
465+ PyObject* output_class) {
455466 // When `max_depth != -1`, we want objects/dicts/lists at `max_depth`
456467 // to be kept as DataItems and not converted to Python objects.
457468 // To do that, we extract a DataSlice, and keep track of the leaf DataItems.
@@ -474,10 +485,11 @@ PyObject* absl_nullable ToPyImpl(const DataSlice& ds, DataBagPtr bag,
474485 /* casting_callback=*/ std::nullopt , std::move (leaf_callback)),
475486 arolla::python::SetPyErrFromStatus (_));
476487 return ToPyImplInternal (extracted_ds, ds.GetBag (), obj_as_dict,
477- include_missing_attrs, objects_not_to_convert);
488+ include_missing_attrs, objects_not_to_convert,
489+ output_class);
478490 }
479491 return ToPyImplInternal (ds, nullptr , obj_as_dict, include_missing_attrs,
480- objects_not_to_convert);
492+ objects_not_to_convert, output_class );
481493}
482494
483495} // namespace
@@ -487,9 +499,9 @@ PyObject* absl_nullable PyDataSlice_to_py(PyObject* self,
487499 Py_ssize_t nargs) {
488500 arolla::python::DCheckPyGIL ();
489501 arolla::python::PyCancellationScope cancellation_scope;
490- if (nargs != 3 ) {
502+ if (nargs != 4 ) {
491503 PyErr_Format (PyExc_ValueError,
492- " DataSlice._to_py_impl accepts exactly 3 arguments, got %d" ,
504+ " DataSlice._to_py_impl accepts exactly 4 arguments, got %d" ,
493505 nargs);
494506 return nullptr ;
495507 }
@@ -518,9 +530,15 @@ PyObject* absl_nullable PyDataSlice_to_py(PyObject* self,
518530 return nullptr ;
519531 }
520532 const bool include_missing_attrs = py_args[2 ] == Py_True;
533+ PyObject* output_class = py_args[3 ];
534+ if (output_class == nullptr ) {
535+ PyErr_Format (PyExc_TypeError,
536+ " expecting output_class to be a class, got nullptr" );
537+ return nullptr ;
538+ }
521539
522540 return ToPyImpl (ds, ds.GetBag (), max_depth, obj_as_dict,
523- include_missing_attrs);
541+ include_missing_attrs, output_class );
524542}
525543
526544} // namespace koladata::python
0 commit comments