Skip to content

Commit e631564

Browse files
apronchenkovcopybara-github
authored andcommitted
Expose include_missing parameter in kd.functor.map_py_fn()
PiperOrigin-RevId: 714125675 Change-Id: I7de628b5fbb1a46c1292f6c464ddfeb5f7d62bb0
1 parent 3e5f089 commit e631564

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

py/koladata/functor/functor_factories.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,7 @@ def map_py_fn(
435435
schema: Any = None,
436436
max_threads: Any = 1,
437437
ndim: Any = 0,
438+
include_missing: Any = None,
438439
**defaults: Any,
439440
) -> data_slice.DataSlice:
440441
"""Returns a Koda functor wrapping a python function for kd.map_py.
@@ -447,6 +448,9 @@ def map_py_fn(
447448
schema: The schema to use for resulting DataSlice.
448449
max_threads: maximum number of threads to use.
449450
ndim: Dimensionality of items to pass to `f`.
451+
include_missing: Specifies whether `f` should be applied to the missing
452+
items. By default, the function is applied to all items including the
453+
missing. `include_missing=False` can only be used with `ndim=0`.
450454
**defaults: Keyword defaults to pass to the function. The values in this map
451455
may be kde expressions, format strings, or 0-dim DataSlices. See the
452456
docstring for py_fn for more details.
@@ -459,7 +463,7 @@ def map_py_fn(
459463
schema=py_boxing.as_qvalue(schema),
460464
max_threads=py_boxing.as_qvalue(max_threads),
461465
ndim=py_boxing.as_qvalue(ndim),
462-
include_missing=py_boxing.as_qvalue(True),
466+
include_missing=py_boxing.as_qvalue(include_missing),
463467
item_completed_callback=py_boxing.as_qvalue(None),
464468
kwargs=I.kwargs,
465469
),

py/koladata/functor/functor_factories_test.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ def test_fstr_fn_expr(self):
478478
functor_factories.fstr_fn(f'{kde.select(I.x, kdi.present):s}'),
479479
x=ds([1, None]),
480480
),
481-
ds(['1', None])
481+
ds(['1', None]),
482482
)
483483

484484
def test_fstr_fn_variable(self):
@@ -773,13 +773,27 @@ def fn(x):
773773

774774
def test_map_py_fn_with_missing_items(self):
775775
def fn(x):
776-
return -1 if x is None else x
776+
return -1 if x is None else x + 1
777777

778778
self.assertEqual(
779779
kd.call(
780780
functor_factories.map_py_fn(fn), x=kdi.slice([1, None, None])
781781
).to_py(),
782-
[1, -1, -1],
782+
[2, -1, -1],
783+
)
784+
self.assertEqual(
785+
kd.call(
786+
functor_factories.map_py_fn(fn, include_missing=True),
787+
x=kdi.slice([1, None, None]),
788+
).to_py(),
789+
[2, -1, -1],
790+
)
791+
self.assertEqual(
792+
kd.call(
793+
functor_factories.map_py_fn(fn, include_missing=False),
794+
x=kdi.slice([1, None, None]),
795+
).to_py(),
796+
[2, None, None],
783797
)
784798

785799
def test_map_py_default_arguments(self):

0 commit comments

Comments
 (0)