
Commit f7e4931

feat: support empty arrays, improve ibis.array() API
Picking out the array stuff from #8666
1 parent 33ec754 commit f7e4931
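
For orientation, here is a minimal sketch of the call patterns this commit enables, drawn from the new tests in test_array.py below (the assertions mirror those tests; nothing here is beyond what the diff exercises):

import ibis
import ibis.expr.datatypes as dt

# The element type is still inferred when values are present.
a = ibis.array([1, 2, 3])
assert a.type() == dt.Array(value_type=dt.int8)

# An explicit type casts every element to the array's value type.
strings = ibis.array([1, 2, 3], type="array<string>")

# Empty and None inputs are now accepted, but only with an explicit array type.
empty = ibis.array([], type="array<string>")
null_arr = ibis.array(None, type="array<string>")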

13 files changed: +201, -54 lines

ibis/backends/dask/executor.py (+4, -2)

@@ -28,6 +28,7 @@
     plan,
 )
 from ibis.common.exceptions import UnboundExpressionError, UnsupportedOperationError
+from ibis.formats.numpy import NumpyType
 from ibis.formats.pandas import PandasData, PandasType
 from ibis.util import gen_name

@@ -155,9 +156,10 @@ def mapper(df, cases):
         return cls.partitionwise(mapper, kwargs, name=op.name, dtype=dtype)

     @classmethod
-    def visit(cls, op: ops.Array, exprs):
+    def visit(cls, op: ops.Array, exprs, dtype):
+        np_type = NumpyType.from_ibis(dtype)
         return cls.rowwise(
-            lambda row: np.array(row, dtype=object), exprs, name=op.name, dtype=object
+            lambda row: np.array(row, dtype=np_type), exprs, name=op.name, dtype=object
         )

     @classmethod

ibis/backends/dask/helpers.py (+2, -2)

@@ -30,7 +30,7 @@ def concat(cls, dfs, **kwargs):

     @classmethod
     def asseries(cls, value, like=None):
-        """Ensure that value is a pandas Series object, broadcast if necessary."""
+        """Ensure that value is a dask Series object, broadcast if necessary."""

         if isinstance(value, dd.Series):
             return value
@@ -50,7 +50,7 @@ def asseries(cls, value, like=None):
         elif isinstance(value, pd.Series):
             return dd.from_pandas(value, npartitions=1)
         elif like is not None:
-            if isinstance(value, (tuple, list, dict)):
+            if isinstance(value, (tuple, list, dict, np.ndarray)):
                 fn = lambda df: pd.Series([value] * len(df), index=df.index)
             else:
                 fn = lambda df: pd.Series(value, index=df.index)

ibis/backends/pandas/executor.py (+4, -2)

@@ -49,7 +49,9 @@ def visit(cls, op: ops.Node, **kwargs):

     @classmethod
     def visit(cls, op: ops.Literal, value, dtype):
-        if dtype.is_interval():
+        if value is None:
+            value = None
+        elif dtype.is_interval():
             value = pd.Timedelta(value, dtype.unit.short)
         elif dtype.is_array():
             value = np.array(value)
@@ -220,7 +222,7 @@ def visit(cls, op: ops.FindInSet, needle, values):
         return pd.Series(result, name=op.name)

     @classmethod
-    def visit(cls, op: ops.Array, exprs):
+    def visit(cls, op: ops.Array, exprs, dtype):
         return cls.rowwise(lambda row: np.array(row, dtype=object), exprs)

     @classmethod

ibis/backends/polars/compiler.py (+16, -11)

@@ -87,25 +87,27 @@ def literal(op, **_):
     value = op.value
     dtype = op.dtype

-    if dtype.is_array():
-        value = pl.Series("", value)
-        typ = PolarsType.from_ibis(dtype)
-        val = pl.lit(value, dtype=typ)
-        return val.implode()
+    # There are some interval types that _make_duration() can handle,
+    # but PolarsType.from_ibis can't, so we need to handle them here.
+    if dtype.is_interval():
+        return _make_duration(value, dtype)
+
+    typ = PolarsType.from_ibis(dtype)
+    if value is None:
+        return pl.lit(None, dtype=typ)
+    elif dtype.is_array():
+        return pl.lit(pl.Series("", value).implode(), dtype=typ)
     elif dtype.is_struct():
         values = [
             pl.lit(v, dtype=PolarsType.from_ibis(dtype[k])).alias(k)
             for k, v in value.items()
         ]
         return pl.struct(values)
-    elif dtype.is_interval():
-        return _make_duration(value, dtype)
     elif dtype.is_null():
         return pl.lit(value)
     elif dtype.is_binary():
         return pl.lit(value)
     else:
-        typ = PolarsType.from_ibis(dtype)
         return pl.lit(op.value, dtype=typ)


@@ -974,9 +976,12 @@ def array_concat(op, **kw):


 @translate.register(ops.Array)
-def array_column(op, **kw):
-    cols = [translate(col, **kw) for col in op.exprs]
-    return pl.concat_list(cols)
+def array_literal(op, **kw):
+    pdt = PolarsType.from_ibis(op.dtype)
+    if op.exprs:
+        return pl.concat_list([translate(col, **kw) for col in op.exprs]).cast(pdt)
+    else:
+        return pl.lit([], dtype=pdt)


 @translate.register(ops.ArrayCollect)
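
As a quick check of the new empty-array path, a hedged sketch using a throwaway local Polars connection (the connection is illustrative; the expectation mirrors test_array_factory_empty below):

import ibis

con = ibis.polars.connect()
# ops.Array with no exprs now compiles to pl.lit([], dtype=...) rather than
# pl.concat_list([]), so an empty typed array executes to an empty list.
assert con.execute(ibis.array([], type="array<string>")) == []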

ibis/backends/sql/compiler.py (+14, -2)

@@ -1019,8 +1019,20 @@ def visit_InSubquery(self, op, *, rel, needle):
         query = sg.select(STAR).from_(query)
         return needle.isin(query=query)

-    def visit_Array(self, op, *, exprs):
-        return self.f.array(*exprs)
+    def visit_Array(self, op, *, exprs, dtype):
+        if not exprs:
+            return self.cast(self.f.array(), dtype)
+
+        def maybe_cast(ibis_val, sg_expr):
+            if ibis_val.dtype == dtype.value_type:
+                return sg_expr
+            else:
+                return self.cast(sg_expr, dtype.value_type)
+
+        cast_exprs = [
+            maybe_cast(ibis_val, sg_expr) for ibis_val, sg_expr in zip(op.exprs, exprs)
+        ]
+        return self.f.array(*cast_exprs)

     def visit_StructColumn(self, op, *, names, values):
         return sge.Struct.from_arg_list(
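
The effect on the SQL backends, sketched at the expression level; the rendered SQL differs per dialect, so the comments only describe the shape of the output rather than quoting it:

import ibis

# Each element whose dtype differs from the array's value type gets its own
# cast, so a mixed array typed as array<string> executes to [None, "1"]
# on conforming backends (see test_array_factory_null_mixed below).
mixed = ibis.array([None, 1], type="array<string>")

# An empty array compiles to the equivalent of CAST(array() AS <array type>),
# which carries the element type even though there are no elements.
empty = ibis.array([], type="array<int64>")
# ibis.to_sql(empty, dialect="duckdb")  # inspect the dialect-specific output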

ibis/backends/tests/test_array.py (+80, -5)

@@ -31,6 +31,7 @@
     PySparkAnalysisException,
     TrinoUserError,
 )
+from ibis.common.annotations import ValidationError
 from ibis.common.collections import frozendict

 pytestmark = [
@@ -72,6 +73,85 @@
     # list.


+def test_array_factory(con):
+    a = ibis.array([1, 2, 3])
+    assert a.type() == dt.Array(value_type=dt.int8)
+    assert con.execute(a) == [1, 2, 3]
+
+    a2 = ibis.array(a)
+    assert a2.type() == dt.Array(value_type=dt.int8)
+    assert con.execute(a2) == [1, 2, 3]
+
+
+@pytest.mark.broken(
+    ["pandas", "dask"],
+    raises=AssertionError,
+    reason="results in [1, 2, 3]",
+)
+def test_array_factory_typed(con):
+    typed = ibis.array([1, 2, 3], type="array<string>")
+    assert con.execute(typed) == ["1", "2", "3"]
+
+    typed2 = ibis.array(ibis.array([1, 2, 3]), type="array<string>")
+    assert con.execute(typed2) == ["1", "2", "3"]
+
+
+@pytest.mark.notimpl("flink", raises=Py4JJavaError)
+@pytest.mark.notimpl(["pandas", "dask"], raises=ValueError)
+def test_array_factory_empty(con):
+    with pytest.raises(ValidationError):
+        ibis.array([])
+
+    empty_typed = ibis.array([], type="array<string>")
+    assert empty_typed.type() == dt.Array(value_type=dt.string)
+    assert con.execute(empty_typed) == []
+
+
+@pytest.mark.notyet(
+    "clickhouse", raises=ClickHouseDatabaseError, reason="nested types can't be NULL"
+)
+@pytest.mark.notyet(
+    "flink", raises=Py4JJavaError, reason="Parameters must be of the same type"
+)
+def test_array_factory_null(con):
+    with pytest.raises(ValidationError):
+        ibis.array(None)
+    with pytest.raises(ValidationError):
+        ibis.array(None, type="int64")
+    none_typed = ibis.array(None, type="array<string>")
+    assert none_typed.type() == dt.Array(value_type=dt.string)
+    assert con.execute(none_typed) is None
+
+    nones = ibis.array([None, None], type="array<string>")
+    assert nones.type() == dt.Array(value_type=dt.string)
+    assert con.execute(nones) == [None, None]
+
+    # Execute a real value here, so the backends that don't support arrays
+    # actually xfail as we expect them to.
+    # Otherwise we would have to @mark.xfail every test in this file besides this one.
+    assert con.execute(ibis.array([1, 2])) == [1, 2]
+
+
+@pytest.mark.broken(
+    ["datafusion", "flink", "polars"],
+    raises=AssertionError,
+    reason="[None, 1] executes to [np.nan, 1.0]",
+)
+@pytest.mark.broken(
+    ["pandas", "dask"],
+    raises=AssertionError,
+    reason="even with explicit cast, results in [None, 1]",
+)
+def test_array_factory_null_mixed(con):
+    none_and_val = ibis.array([None, 1])
+    assert none_and_val.type() == dt.Array(value_type=dt.int8)
+    assert con.execute(none_and_val) == [None, 1]
+
+    none_and_val_typed = ibis.array([None, 1], type="array<string>")
+    assert none_and_val_typed.type() == dt.Array(value_type=dt.string)
+    assert con.execute(none_and_val_typed) == [None, "1"]
+
+
 def test_array_column(backend, alltypes, df):
     expr = ibis.array(
         [alltypes["double_col"], alltypes["double_col"], 5.0, ibis.literal(6.0)]
@@ -1354,11 +1434,6 @@ def test_unnest_range(con):
             id="array",
             marks=[
                 pytest.mark.notyet(["bigquery"], raises=GoogleBadRequest),
-                pytest.mark.broken(
-                    ["polars"],
-                    reason="expression input not supported with nested arrays",
-                    raises=TypeError,
-                ),
             ],
         ),
     ],

ibis/backends/tests/test_generic.py (+3, -4)

@@ -1431,13 +1431,12 @@ def query(t, group_cols):
     snapshot.assert_match(str(ibis.to_sql(t3, dialect=con.name)), "out.sql")


-@pytest.mark.notimpl(["oracle", "exasol"], raises=com.OperationNotDefinedError)
-@pytest.mark.notimpl(["druid"], raises=AssertionError)
 @pytest.mark.notyet(
-    ["datafusion", "impala", "mssql", "mysql", "sqlite"],
+    ["datafusion", "exasol", "impala", "mssql", "mysql", "oracle", "sqlite"],
     reason="backend doesn't support arrays and we don't implement pivot_longer with unions yet",
-    raises=com.OperationNotDefinedError,
+    raises=(com.OperationNotDefinedError, com.UnsupportedBackendType),
 )
+@pytest.mark.notimpl(["druid"], raises=AssertionError)
 @pytest.mark.broken(
     ["trino"],
     reason="invalid code generated for unnesting a struct",

ibis/backends/tests/test_sql.py (+14, -4)

@@ -19,7 +19,7 @@
         ibis.array([432]),
         marks=[
             pytest.mark.never(
-                ["mysql", "mssql", "oracle", "impala", "sqlite"],
+                ["exasol", "mysql", "mssql", "oracle", "impala", "sqlite"],
                 raises=(exc.OperationNotDefinedError, exc.UnsupportedBackendType),
                 reason="arrays not supported in the backend",
             ),
@@ -30,8 +30,18 @@
         ibis.struct(dict(abc=432)),
         marks=[
             pytest.mark.never(
-                ["impala", "mysql", "sqlite", "mssql", "exasol"],
-                raises=(NotImplementedError, exc.UnsupportedBackendType),
+                [
+                    "exasol",
+                    "impala",
+                    "mysql",
+                    "sqlite",
+                    "mssql",
+                ],
+                raises=(
+                    exc.OperationNotDefinedError,
+                    NotImplementedError,
+                    exc.UnsupportedBackendType,
+                ),
                 reason="structs not supported in the backend",
             ),
             pytest.mark.notimpl(
@@ -104,7 +114,7 @@ def test_isin_bug(con, snapshot):
 @pytest.mark.notyet(
     ["datafusion", "exasol", "oracle", "flink", "risingwave"],
     reason="no unnest support",
-    raises=exc.OperationNotDefinedError,
+    raises=(exc.OperationNotDefinedError, exc.UnsupportedBackendType),
 )
 @pytest.mark.notyet(
     ["sqlite", "mysql", "druid", "impala", "mssql"], reason="no unnest support upstream"

ibis/backends/tests/test_string.py (+11, -5)

@@ -16,6 +16,7 @@
     ClickHouseDatabaseError,
     OracleDatabaseError,
     PsycoPg2InternalError,
+    PyDruidProgrammingError,
     PyODBCProgrammingError,
 )
 from ibis.common.annotations import ValidationError
@@ -835,21 +836,26 @@ def test_capitalize(con, inp, expected):
     assert pd.isnull(result)


+@pytest.mark.never(
+    ["exasol", "impala", "mssql", "mysql", "sqlite"],
+    reason="Backend doesn't support arrays",
+    raises=(com.OperationNotDefinedError, com.UnsupportedBackendType),
+)
 @pytest.mark.notimpl(
     [
         "dask",
         "pandas",
         "polars",
         "oracle",
         "flink",
-        "sqlite",
-        "mssql",
-        "mysql",
-        "exasol",
-        "impala",
     ],
     raises=com.OperationNotDefinedError,
 )
+@pytest.mark.broken(
+    "druid",
+    raises=PyDruidProgrammingError,
+    reason="ibis.array() has a cast, and we compile the dtype to 'VARCHAR[]' instead of 'ARRAY<STRING>' as needed",
+)
 def test_array_string_join(con):
     s = ibis.array(["a", "b", "c"])
     expected = "a,b,c"

ibis/expr/operations/arrays.py (+7, -6)

@@ -19,14 +19,15 @@ class Array(Value):
     """Construct an array."""

     exprs: VarTuple[Value]
+    dtype: Optional[dt.Array] = None

-    @attribute
-    def shape(self):
-        return rlz.highest_precedence_shape(self.exprs)
+    shape = rlz.shape_like("exprs")

-    @attribute
-    def dtype(self):
-        return dt.Array(rlz.highest_precedence_dtype(self.exprs))
+    def __init__(self, exprs, dtype: dt.Array | None = None):
+        # If len(exprs) == 0, the caller is responsible for providing a dtype
+        if dtype is None:
+            dtype = dt.Array(rlz.highest_precedence_dtype(exprs))
+        super().__init__(exprs=exprs, dtype=dtype)


 @public
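
In other words, dtype is now a stored field with an inference fallback rather than a derived attribute. A small sketch, constructing the op directly purely for illustration (the public entry point remains ibis.array):

import ibis
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops

# With elements present, the dtype defaults to the highest-precedence
# element type wrapped in dt.Array, as before.
node = ops.Array(exprs=(ibis.literal(1).op(), ibis.literal(2).op()))
assert node.dtype == dt.Array(dt.int8)

# With no elements there is nothing to infer from, so the caller
# (ibis.array) must supply the dtype explicitly.
empty = ops.Array(exprs=(), dtype=dt.Array(dt.string))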

ibis/expr/rules.py (+4)

@@ -5,6 +5,7 @@

 from public import public

+import ibis.expr.datashape as ds
 import ibis.expr.datatypes as dt
 import ibis.expr.operations as ops
 from ibis import util
@@ -16,6 +17,9 @@

 @public
 def highest_precedence_shape(nodes):
+    nodes = tuple(nodes)
+    if len(nodes) == 0:
+        return ds.scalar
     return max(node.shape for node in nodes)

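This guard exists so that shape inference no longer calls max() on an empty sequence; a one-line sketch of the new behavior:

import ibis.expr.datashape as ds
from ibis.expr.rules import highest_precedence_shape

# An empty ops.Array now defaults to a scalar shape instead of raising.
assert highest_precedence_shape(()) == ds.scalar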

(unnamed snapshot file)

@@ -1,2 +1,2 @@
 DummyTable
-  foo: Array([1])
+  foo: Array(exprs=[1], dtype=array<int8>)
