Skip to content

Commit 646e530

Browse files
atorerocopybara-github
authored andcommitted
Add output_class argument to to_py. It will be actually implemented as a follow-up.
PiperOrigin-RevId: 875733422 Change-Id: Id8aa2246b9c7ce79837c4dc8c333edeec539e510
1 parent 8b7a761 commit 646e530

File tree

6 files changed

+49
-22
lines changed

6 files changed

+49
-22
lines changed

docs/api_reference.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10901,7 +10901,7 @@ Alias for [kd.proto.to_proto_bytes](#kd.proto.to_proto_bytes) operator.
1090110901

1090210902
Alias for [kd.proto.to_proto_json](#kd.proto.to_proto_json) operator.
1090310903

10904-
### `kd.to_py(ds: DataSlice, max_depth: int = 2, obj_as_dict: bool = False, include_missing_attrs: bool = True) -> Any` {#kd.to_py}
10904+
### `kd.to_py(ds: DataSlice, max_depth: int = 2, obj_as_dict: bool = False, include_missing_attrs: bool = True, output_class: type[Any] | None = None) -> Any` {#kd.to_py}
1090510905

1090610906
<pre class="no-copy"><code class="lang-text no-auto-prettify">Returns a readable python object from a DataSlice.
1090710907

@@ -10915,7 +10915,8 @@ Args:
1091510915
obj_as_dict: Whether to convert objects to python dicts. By default objects
1091610916
are converted to automatically constructed &#39;Obj&#39; dataclass instances.
1091710917
include_missing_attrs: whether to include attributes with None value in
10918-
objects.</code></pre>
10918+
objects.
10919+
output_class: If not None, will be used recursively as the output type.</code></pre>
1091910920

1092010921
### `kd.to_pylist(x: DataSlice) -> list[Any]` {#kd.to_pylist}
1092110922

@@ -13454,7 +13455,7 @@ Args:
1345413455
Returns:
1345513456
A new DataSlice with items selected by indices.</code></pre>
1345613457

13457-
### `DataSlice.to_py(ds: DataSlice, max_depth: int = 2, obj_as_dict: bool = False, include_missing_attrs: bool = True) -> Any` {#DataSlice.to_py}
13458+
### `DataSlice.to_py(ds: DataSlice, max_depth: int = 2, obj_as_dict: bool = False, include_missing_attrs: bool = True, output_class: Any | None = None) -> Any` {#DataSlice.to_py}
1345813459

1345913460
<pre class="no-copy"><code class="lang-text no-auto-prettify">Returns a readable python object from a DataSlice.
1346013461

@@ -13467,7 +13468,8 @@ Args:
1346713468
obj_as_dict: Whether to convert objects to python dicts. By default objects
1346813469
are converted to automatically constructed &#39;Obj&#39; dataclass instances.
1346913470
include_missing_attrs: whether to include attributes with None value in
13470-
objects.</code></pre>
13471+
objects.
13472+
output_class: If not None, will be used recursively as the output type.</code></pre>
1347113473

1347213474
### `DataSlice.to_pytree(ds: DataSlice, max_depth: int = 2, include_missing_attrs: bool = True) -> Any` {#DataSlice.to_pytree}
1347313475

py/koladata/base/py_conversions/to_py.cc

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,14 @@ class ToPyVisitor : internal::AbstractVisitor {
9898
public:
9999
ToPyVisitor(bool obj_as_dict, bool include_missing_attrs,
100100
const absl::flat_hash_set<ObjectId>& objects_not_to_convert,
101-
DataBagPtr db, internal::DataBagImpl::FallbackSpan fallback_span)
101+
DataBagPtr db, internal::DataBagImpl::FallbackSpan fallback_span,
102+
PyObject* output_class)
102103
: obj_as_dict_(obj_as_dict),
103104
include_missing_attrs_(include_missing_attrs),
104105
objects_not_to_convert_(objects_not_to_convert),
105106
db_(std::move(db)),
106-
fallback_span_(fallback_span) {}
107+
fallback_span_(fallback_span),
108+
output_class_(PyObjectPtr::NewRef(output_class)) {}
107109

108110
ItemToPyConverter GetItemToPyConverter(const DataItem& schema) {
109111
return [&](const DataItem& item) -> absl::StatusOr<PyObjectPtr> {
@@ -408,26 +410,34 @@ class ToPyVisitor : internal::AbstractVisitor {
408410
DataClassesUtil dataclasses_util_;
409411
DataBagPtr db_;
410412
internal::DataBagImpl::FallbackSpan fallback_span_;
413+
PyObjectPtr output_class_;
411414
};
412415

413416
PyObject* absl_nullable ToPyImplInternal(
414417
const DataSlice& ds, DataBagPtr bag, bool obj_as_dict,
415-
bool include_missing_attrs, const absl::flat_hash_set<ObjectId>& leaf_ids) {
418+
bool include_missing_attrs, const absl::flat_hash_set<ObjectId>& leaf_ids,
419+
PyObject* output_class) {
416420
if (ds.IsEmpty() || bag == nullptr ||
417421
GetNarrowedSchema(ds).is_primitive_schema()) {
418422
ASSIGN_OR_RETURN(PyObjectPtr res, PyObjectFromDataSlice(ds),
419423
arolla::python::SetPyErrFromStatus(_));
420424
return res.release();
421425
}
426+
if (obj_as_dict && output_class != Py_None) {
427+
arolla::python::SetPyErrFromStatus(absl::InvalidArgumentError(
428+
"obj_as_dict cannot be used with output_class"));
429+
return nullptr;
430+
}
422431
const DataSlice& schema = ds.GetSchema();
423432
DCHECK(bag != nullptr);
424433

425434
FlattenFallbackFinder fb_finder(*bag);
426435
const internal::DataBagImpl::FallbackSpan fallback_span =
427436
fb_finder.GetFlattenFallbacks();
428437
// We use original DataBag in ToPyVisitor.
429-
std::shared_ptr<ToPyVisitor> visitor = std::make_shared<ToPyVisitor>(
430-
obj_as_dict, include_missing_attrs, leaf_ids, bag, fallback_span);
438+
std::shared_ptr<ToPyVisitor> visitor =
439+
std::make_shared<ToPyVisitor>(obj_as_dict, include_missing_attrs,
440+
leaf_ids, bag, fallback_span, output_class);
431441
// We use extracted DataBag for traversal.
432442
internal::Traverser<ToPyVisitor> traverse_op(bag->GetImpl(), fallback_span,
433443
visitor);
@@ -451,7 +461,8 @@ PyObject* absl_nullable ToPyImplInternal(
451461

452462
PyObject* absl_nullable ToPyImpl(const DataSlice& ds, DataBagPtr bag,
453463
int max_depth, bool obj_as_dict,
454-
bool include_missing_attrs) {
464+
bool include_missing_attrs,
465+
PyObject* output_class) {
455466
// When `max_depth != -1`, we want objects/dicts/lists at `max_depth`
456467
// to be kept as DataItems and not converted to Python objects.
457468
// To do that, we extract a DataSlice, and keep track of the leaf DataItems.
@@ -474,10 +485,11 @@ PyObject* absl_nullable ToPyImpl(const DataSlice& ds, DataBagPtr bag,
474485
/*casting_callback=*/std::nullopt, std::move(leaf_callback)),
475486
arolla::python::SetPyErrFromStatus(_));
476487
return ToPyImplInternal(extracted_ds, ds.GetBag(), obj_as_dict,
477-
include_missing_attrs, objects_not_to_convert);
488+
include_missing_attrs, objects_not_to_convert,
489+
output_class);
478490
}
479491
return ToPyImplInternal(ds, nullptr, obj_as_dict, include_missing_attrs,
480-
objects_not_to_convert);
492+
objects_not_to_convert, output_class);
481493
}
482494

483495
} // namespace
@@ -487,9 +499,9 @@ PyObject* absl_nullable PyDataSlice_to_py(PyObject* self,
487499
Py_ssize_t nargs) {
488500
arolla::python::DCheckPyGIL();
489501
arolla::python::PyCancellationScope cancellation_scope;
490-
if (nargs != 3) {
502+
if (nargs != 4) {
491503
PyErr_Format(PyExc_ValueError,
492-
"DataSlice._to_py_impl accepts exactly 3 arguments, got %d",
504+
"DataSlice._to_py_impl accepts exactly 4 arguments, got %d",
493505
nargs);
494506
return nullptr;
495507
}
@@ -518,9 +530,15 @@ PyObject* absl_nullable PyDataSlice_to_py(PyObject* self,
518530
return nullptr;
519531
}
520532
const bool include_missing_attrs = py_args[2] == Py_True;
533+
PyObject* output_class = py_args[3];
534+
if (output_class == nullptr) {
535+
PyErr_Format(PyExc_TypeError,
536+
"expecting output_class to be a class, got nullptr");
537+
return nullptr;
538+
}
521539

522540
return ToPyImpl(ds, ds.GetBag(), max_depth, obj_as_dict,
523-
include_missing_attrs);
541+
include_missing_attrs, output_class);
524542
}
525543

526544
} // namespace koladata::python

py/koladata/functions/py_conversions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def to_py(
7070
max_depth: int = 2,
7171
obj_as_dict: bool = False,
7272
include_missing_attrs: bool = True,
73+
output_class: type[Any] | None = None,
7374
) -> Any:
7475
"""Returns a readable python object from a DataSlice.
7576
@@ -84,9 +85,10 @@ def to_py(
8485
are converted to automatically constructed 'Obj' dataclass instances.
8586
include_missing_attrs: whether to include attributes with None value in
8687
objects.
88+
output_class: If not None, will be used recursively as the output type.
8789
"""
8890
return ds._to_py_impl( # pylint: disable=protected-access
89-
max_depth, obj_as_dict, include_missing_attrs
91+
max_depth, obj_as_dict, include_missing_attrs, output_class
9092
)
9193

9294

py/koladata/types/data_slice.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1020,7 +1020,7 @@ If `schema` is set, that schema is used, otherwise the schema is inferred from
10201020
"Returns an UNSPECIFIED with DataSlice QType."},
10211021
{"_to_py_impl", (PyCFunction)PyDataSlice_to_py, METH_FASTCALL,
10221022
"_to_py_impl(ds, /, max_depth=-1, obj_as_dict=False, "
1023-
"include_missing_attrs=True)\n"
1023+
"include_missing_attrs=True, output_class=None)\n"
10241024
"--\n\n"
10251025
"Returns a Python object equivalent to this DataSlice.\n"},
10261026
{"internal_as_py", PyDataSlice_internal_as_py, METH_NOARGS,

py/koladata/types/data_slice.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -498,12 +498,14 @@ def _explode(
498498
return _eval_op('kd.explode', self, ndim)
499499

500500

501+
# TODO(b/483020237) make arguments keyword-only.
501502
@add_method(DataSlice, 'to_py')
502503
def to_py(
503504
ds: DataSlice,
504505
max_depth: int = 2,
505506
obj_as_dict: bool = False,
506507
include_missing_attrs: bool = True,
508+
output_class: Any | None = None,
507509
) -> Any:
508510
"""Returns a readable python object from a DataSlice.
509511
@@ -517,15 +519,18 @@ def to_py(
517519
are converted to automatically constructed 'Obj' dataclass instances.
518520
include_missing_attrs: whether to include attributes with None value in
519521
objects.
522+
output_class: If not None, will be used recursively as the output type.
520523
"""
521524
return ds._to_py_impl( # pylint: disable=protected-access
522-
max_depth, obj_as_dict, include_missing_attrs
525+
max_depth, obj_as_dict, include_missing_attrs, output_class
523526
)
524527

525528

526529
@add_method(DataSlice, 'to_pytree')
527530
def to_pytree(
528-
ds: DataSlice, max_depth: int = 2, include_missing_attrs: bool = True
531+
ds: DataSlice,
532+
max_depth: int = 2,
533+
include_missing_attrs: bool = True,
529534
) -> Any:
530535
"""Returns a readable python object from a DataSlice.
531536
@@ -542,7 +547,7 @@ def to_pytree(
542547
objects.
543548
"""
544549
return ds._to_py_impl( # pylint: disable=protected-access
545-
max_depth, True, include_missing_attrs
550+
max_depth, True, include_missing_attrs, None
546551
)
547552

548553

py/koladata/types/type_defs.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,13 +152,13 @@ class DataSlice(arolla.QValue):
152152
def with_name(self, name: Any) -> Self: ...
153153

154154
def to_py(
155-
self, max_depth: int = 2, obj_as_dict: bool = False, include_missing_attrs: bool = True
155+
self, max_depth: int = 2, obj_as_dict: bool = False, include_missing_attrs: bool = True, output_class: Any | None = None
156156
) -> Any: ...
157157
def to_pytree(
158158
self, max_depth: int = 2, include_missing_attrs: bool = True
159159
) -> Any: ...
160160
def _to_py_impl(
161-
self, max_depth: int = 2, obj_as_dict: bool = False, include_missing_attrs: bool = True
161+
self, max_depth: int = 2, obj_as_dict: bool = False, include_missing_attrs: bool = True, output_class: Any | None = None
162162
) -> Any: ...
163163
def _to_proto(self, message_class: Any) -> Any: ...
164164
def internal_as_py(self) -> Any: ...

0 commit comments

Comments
 (0)