Skip to content

feat(array): only consider arg position, NOT defaultness or keywordness, when binding lambdas in Array.map() and Array.filter() #11116

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 18 additions & 56 deletions ibis/backends/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,11 +542,13 @@ def test_array_slice(backend, start, stop):
)
@pytest.mark.parametrize(
"func",
[lambda x: x + 1, partial(lambda x, y: x + y, y=1), ibis._ + 1],
ids=["lambda", "partial", "deferred"],
[
pytest.param(lambda x: x + 1, id="lambda"),
pytest.param(partial(lambda x, idx, y: x + y, y=1), id="partial"),
pytest.param(ibis._ + 1, id="deferred"),
],
)
def test_array_map(con, input, output, func):
t = ibis.memtable(input, schema=ibis.schema(dict(a="!array<int8>")))
t = ibis.memtable(input, schema=ibis.schema(dict(a="!array<int8>")))
expected = pd.Series(output["a"])

Expand Down Expand Up @@ -601,11 +603,12 @@ def test_array_map(con, input, output, func):
)
@pytest.mark.parametrize(
"func",
[lambda x, i: x + 1 + i, partial(lambda x, y, i: x + y + i, y=1)],
ids=["lambda", "partial"],
[
pytest.param(lambda x, i: x + 1 + i, id="lambda"),
pytest.param(partial(lambda x, i, y: x + y + i, y=1), id="partial"),
],
)
def test_array_map_with_index(con, input, output, func):
t = ibis.memtable(input, schema=ibis.schema(dict(a="!array<int8>")))
t = ibis.memtable(input, schema=ibis.schema(dict(a="!array<int8>")))
expected = pd.Series(output["a"])

Expand Down Expand Up @@ -649,8 +652,11 @@ def test_array_map_with_index(con, input, output, func):
)
@pytest.mark.parametrize(
"predicate",
[lambda x: x > 1, partial(lambda x, y: x > y, y=1), ibis._ > 1],
ids=["lambda", "partial", "deferred"],
[
pytest.param(lambda x: x > 1, id="lambda"),
pytest.param(partial(lambda x, i, y: x > y, y=1), id="partial"),
pytest.param(ibis._ > 1, id="deferred"),
],
)
def test_array_filter(con, input, output, predicate):
t = ibis.memtable(input, schema=ibis.schema(dict(a="!array<int8>")))
Expand Down Expand Up @@ -696,8 +702,10 @@ def test_array_filter(con, input, output, predicate):
)
@pytest.mark.parametrize(
"predicate",
[lambda x, i: x + (i - i) > 1, partial(lambda x, y, i: x > y + (i * 0), y=1)],
ids=["lambda", "partial"],
[
pytest.param(lambda x, i: x + (i - i) > 1, id="lambda"),
pytest.param(partial(lambda x, i, y: x > y + (i * 0), y=1), id="partial"),
],
)
def test_array_filter_with_index(con, input, output, predicate):
t = ibis.memtable(input, schema=ibis.schema(dict(a="!array<int8>")))
Expand All @@ -710,52 +718,6 @@ def test_array_filter_with_index(con, input, output, predicate):
)


@builtin_array
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test just duplicates test_array_filter_with_index above, which is why I'm deleting it.

@pytest.mark.notimpl(
["datafusion", "flink", "polars"], raises=com.OperationNotDefinedError
)
@pytest.mark.notimpl(["athena"], raises=PyAthenaDatabaseError)
@pytest.mark.notimpl(
["sqlite"], raises=com.UnsupportedBackendType, reason="Unsupported type: Array..."
)
@pytest.mark.parametrize(
("input", "output"),
[
param(
{"a": [[1, None, None], [4]]},
{"a": [[1, None], [4]]},
id="nulls",
marks=[
pytest.mark.notyet(
["bigquery"],
raises=GoogleBadRequest,
reason="NULLs are not allowed as array elements",
)
],
),
param({"a": [[1, 2], [1]]}, {"a": [[1], [1]]}, id="no_nulls"),
],
)
@pytest.mark.notyet(
"risingwave",
raises=PsycoPg2InternalError,
reason="no support for not null column constraint",
)
@pytest.mark.parametrize(
"predicate",
[lambda x, i: i % 2 == 0, partial(lambda x, y, i: i % 2 == 0, y=1)],
ids=["lambda", "partial"],
)
def test_array_filter_with_index_lambda(con, input, output, predicate):
t = ibis.memtable(input, schema=ibis.schema(dict(a="!array<int8>")))

expr = t.select(a=t.a.filter(predicate))
result = con.to_pyarrow(expr.a)
assert frozenset(map(tuple, result.to_pylist())) == frozenset(
map(tuple, output["a"])
)


@builtin_array
@pytest.mark.notimpl(
["athena"],
Expand Down
189 changes: 108 additions & 81 deletions ibis/expr/types/arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.common.annotations import EMPTY
from ibis.common.deferred import Deferred, deferrable
from ibis.expr.types.generic import Column, Scalar, Value

Expand Down Expand Up @@ -386,37 +385,45 @@ def join(self, sep: str | ir.StringValue, /) -> ir.StringValue:
return ops.ArrayStringJoin(self, sep=sep).to_expr()

def _construct_array_func_inputs(self, func):
shape = self.op().shape
if isinstance(func, Deferred):
name = "_"
index = None
resolve = func.resolve
elif callable(func):
names = (
key
for key, value in inspect.signature(func).parameters.items()
# arg is already bound
if value.default is EMPTY
positional_params = (
param
for param in inspect.signature(func).parameters.values()
if param.kind
in {
param.POSITIONAL_ONLY,
param.POSITIONAL_OR_KEYWORD,
param.VAR_POSITIONAL,
}
)
name = next(names)
index = next(names, None)
try:
element_param = next(positional_params)
except StopIteration:
raise TypeError("function must accept at least 1 positional argument")
name = element_param.name
if element_param.kind == element_param.VAR_POSITIONAL:
index = "index"
else:
index = getattr(next(positional_params, None), "name", None)
resolve = func
else:
raise TypeError(
f"function must be a Deferred or Callable, got `{type(func).__name__}`"
)

shape = self.op().shape
parameter = ops.Argument(name=name, shape=shape, dtype=self.type().value_type)

kwargs = {name: parameter.to_expr()}

value_arg = ops.Argument(name=name, shape=shape, dtype=self.type().value_type)
args = [value_arg.to_expr()]
if index is not None:
index_arg = ops.Argument(name=index, shape=shape, dtype=dt.int64)
kwargs[index] = index_arg.to_expr()
index = index_arg
index = ops.Argument(name=index, shape=shape, dtype=dt.int64)
args.append(index.to_expr())

body = resolve(**kwargs)
return parameter, index, body
body = resolve(*args)
return value_arg, index, body

def map(
self,
Expand All @@ -432,9 +439,9 @@ def map(
func
Function or `Deferred` to apply to each element of this array.

Callables must accept one or two arguments. If there are two
arguments, the second argument is the **zero**-based index of each
element of the array.
Callables are called as `func(element)` or `func(element, idx)`,
depending on whether the function accepts one or at least two positional parameters.
The `idx` argument is the **zero**-based index of each element in the array.

Returns
-------
Expand Down Expand Up @@ -484,23 +491,22 @@ def map(
│ [] │
└────────────────────────────────────────────┘

`.map()` also supports more complex callables like `functools.partial`
and `lambda`s with closures
You can optionally include a second index argument in the mapped function

>>> t.a.map(lambda x, i: i % 2)
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ ArrayMap(a, Modulus(i, 2), x, i) ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ array<int64> │
├──────────────────────────────────┤
│ [0, 1, ... +1] │
│ [0] │
│ [] │
└──────────────────────────────────┘

`.map()` also supports more complex callables like `lambda`s with closures,
`functools.partial`s, and `func(*args)`

>>> from functools import partial
>>> def add(x, y):
... return x + y
>>> add2 = partial(add, y=2)
>>> t.a.map(add2)
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ ArrayMap(a, Add(x, 2), x) ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ array<int64> │
├───────────────────────────┤
│ [3, None, ... +1] │
│ [6] │
│ [] │
└───────────────────────────┘
>>> y = 2
>>> t.a.map(lambda x: x + y)
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
Expand All @@ -513,18 +519,29 @@ def map(
│ [] │
└───────────────────────────┘

You can optionally include a second index argument in the mapped function

>>> t.a.map(lambda x, i: i % 2)
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ ArrayMap(a, Modulus(i, 2), x, i) ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ array<int64> │
├──────────────────────────────────┤
│ [0, 1, ... +1] │
│ [0] │
│ [] │
└──────────────────────────────────┘
>>> from functools import partial
>>> def add(x, i, y):
... return x + i + y
>>> t.a.map(partial(add, y=2))
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ ArrayMap(a, Add(Add(x, i), 2), x, i) ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ array<int64> │
├──────────────────────────────────────┤
│ [3, None, ... +1] │
│ [6] │
│ [] │
└──────────────────────────────────────┘
>>> t.a.map(lambda *elem_and_idx: elem_and_idx[0] + elem_and_idx[1])
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ ArrayMap(a, Add(elem_and_idx, index), elem_and_idx, index) ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ array<int64> │
├────────────────────────────────────────────────────────────┤
│ [1, None, ... +1] │
│ [4] │
│ [] │
└────────────────────────────────────────────────────────────┘
"""
param, index, body = self._construct_array_func_inputs(func)
return ops.ArrayMap(self, param=param, index=index, body=body).to_expr()
Expand All @@ -543,9 +560,9 @@ def filter(
predicate
Function or `Deferred` to use to filter array elements.

Callables must accept one or two arguments. If there are two
arguments, the second argument is the **zero**-based index of each
element of the array.
Callables are called as `predicate(element)` or `predicate(element, idx)`,
depending on whether the function accepts one or at least two positional parameters.
The `idx` argument is the **zero**-based index of each element in the array.

Returns
-------
Expand Down Expand Up @@ -595,35 +612,6 @@ def filter(
│ [] │
└──────────────────────────────────┘

`.filter()` also supports more complex callables like `functools.partial`
and `lambda`s with closures

>>> from functools import partial
>>> def gt(x, y):
... return x > y
>>> gt1 = partial(gt, y=1)
>>> t.a.filter(gt1)
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ ArrayFilter(a, Greater(x, 1), x) ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ array<int64> │
├──────────────────────────────────┤
│ [2] │
│ [4] │
│ [] │
└──────────────────────────────────┘
>>> y = 1
>>> t.a.filter(lambda x: x > y)
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ ArrayFilter(a, Greater(x, 1), x) ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ array<int64> │
├──────────────────────────────────┤
│ [2] │
│ [4] │
│ [] │
└──────────────────────────────────┘

You can optionally include a second index argument in the predicate function

>>> t.a.filter(lambda x, i: i % 4 == 0)
Expand All @@ -636,6 +624,45 @@ def filter(
│ [4] │
│ [] │
└────────────────────────────────────────────────┘

`.filter()` also supports more complex callables like `lambda`s with closures,
`functools.partial`s, and `func(*args)`

>>> y = 1
>>> t.a.filter(lambda x: x > y)
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ ArrayFilter(a, Greater(x, 1), x) ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ array<int64> │
├──────────────────────────────────┤
│ [2] │
│ [4] │
│ [] │
└──────────────────────────────────┘
>>> from functools import partial
>>> def gt(x, i, y):
... return x > y
>>> gt1 = partial(gt, y=1)
>>> t.a.filter(gt1)
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ ArrayFilter(a, Greater(x, 1), x, i) ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ array<int64> │
├─────────────────────────────────────┤
│ [2] │
│ [4] │
│ [] │
└─────────────────────────────────────┘
>>> t.a.filter(lambda *elem_and_idx: elem_and_idx[0] > elem_and_idx[1])
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ ArrayFilter(a, Greater(elem_and_idx, index), elem_and_idx, index) ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ array<int64> │
├───────────────────────────────────────────────────────────────────┤
│ [1] │
│ [4] │
│ [] │
└───────────────────────────────────────────────────────────────────┘
"""
param, index, body = self._construct_array_func_inputs(predicate)
return ops.ArrayFilter(self, param=param, index=index, body=body).to_expr()
Expand Down
Loading