Skip to content

Commit b782c23

Browse files
committed
feat(array): only consider argument position, NOT defaultness, when binding lambdas in Array.map() and Array.filter()
1 parent 21e47ac commit b782c23

File tree

2 files changed

+24
-12
lines changed

2 files changed

+24
-12
lines changed

ibis/expr/types/arrays.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import ibis.expr.datatypes as dt
99
import ibis.expr.operations as ops
10-
from ibis.common.annotations import EMPTY
1110
from ibis.common.deferred import Deferred, deferrable
1211
from ibis.expr.types.generic import Column, Scalar, Value
1312

@@ -391,12 +390,7 @@ def _construct_array_func_inputs(self, func):
391390
index = None
392391
resolve = func.resolve
393392
elif callable(func):
394-
names = (
395-
key
396-
for key, value in inspect.signature(func).parameters.items()
397-
# arg is already bound
398-
if value.default is EMPTY
399-
)
393+
names = (name for name in inspect.signature(func).parameters.keys())
400394
name = next(names)
401395
index = next(names, None)
402396
resolve = func

ibis/tests/expr/test_value_exprs.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1537,11 +1537,20 @@ def test_array_map():
15371537

15381538
r1 = arr.map(_ * 2)
15391539
r2 = arr.map(lambda x: x * 2.0)
1540-
r3 = arr.map(functools.partial(lambda a, b: a + b, b=2))
1540+
r3 = arr.map(lambda x=2: x * 2.0)
1541+
r4 = arr.map(lambda a, idx: a + idx)
1542+
r5 = arr.map(functools.partial(lambda a, idx: a + idx, idx=2))
1543+
r6 = arr.map(functools.partial(lambda a, idx, c: a + c, c=2))
15411544

15421545
assert r1.type() == dt.Array(dt.int16)
15431546
assert r2.type() == dt.Array(dt.float64)
1544-
assert r3.type() == dt.Array(dt.int16)
1547+
assert r3.type() == dt.Array(dt.float64)
1548+
assert r4.type() == dt.Array(dt.int64)
1549+
assert r5.type() == dt.Array(dt.int64)
1550+
assert r6.type() == dt.Array(dt.int16)
1551+
1552+
with pytest.raises(TypeError, match="missing 1 required positional argument"):
1553+
arr.map(lambda a, idx, c: a + 2)
15451554

15461555
with pytest.raises(TypeError, match="must be a Deferred or Callable"):
15471556
# Non-deferred expressions aren't allowed
@@ -1551,13 +1560,22 @@ def test_array_map():
15511560
def test_array_filter():
15521561
arr = ibis.array([1, 2, 3])
15531562

1554-
r1 = arr.filter(lambda x: x < 0)
1555-
r2 = arr.filter(_ < 0)
1556-
r3 = arr.filter(functools.partial(lambda a, b: a == b, b=2))
1563+
r1 = arr.filter(_ < 0)
1564+
r2 = arr.filter(lambda x: x < 0)
1565+
r3 = arr.filter(lambda x=4: x < 0)
1566+
r4 = arr.filter(lambda x, idx: x < idx)
1567+
r5 = arr.filter(functools.partial(lambda a, idx: a == idx, idx=2))
1568+
r6 = arr.filter(functools.partial(lambda a, idx, c: a == c, c=2))
15571569

15581570
assert r1.type() == arr.type()
15591571
assert r2.type() == arr.type()
15601572
assert r3.type() == arr.type()
1573+
assert r4.type() == arr.type()
1574+
assert r5.type() == arr.type()
1575+
assert r6.type() == arr.type()
1576+
1577+
with pytest.raises(TypeError, match="missing 1 required positional argument"):
1578+
arr.filter(lambda a, idx, c: a + 2)
15611579

15621580
with pytest.raises(TypeError, match="must be a Deferred or Callable"):
15631581
# Non-deferred expressions aren't allowed

0 commit comments

Comments
 (0)