Skip to content

Commit a55b1c6

Browse files
committed
feat: support empty arrays, improve ibis.array() API
Picking out the array stuff from #8666
1 parent 33ec754 commit a55b1c6

File tree

17 files changed

+194
-64
lines changed

17 files changed

+194
-64
lines changed

ibis/backends/dask/executor.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
plan,
2929
)
3030
from ibis.common.exceptions import UnboundExpressionError, UnsupportedOperationError
31+
from ibis.formats.numpy import NumpyType
3132
from ibis.formats.pandas import PandasData, PandasType
3233
from ibis.util import gen_name
3334

@@ -155,9 +156,10 @@ def mapper(df, cases):
155156
return cls.partitionwise(mapper, kwargs, name=op.name, dtype=dtype)
156157

157158
@classmethod
158-
def visit(cls, op: ops.Array, exprs):
159+
def visit(cls, op: ops.Array, exprs, dtype):
160+
np_type = NumpyType.from_ibis(dtype)
159161
return cls.rowwise(
160-
lambda row: np.array(row, dtype=object), exprs, name=op.name, dtype=object
162+
lambda row: np.array(row, dtype=np_type), exprs, name=op.name, dtype=object
161163
)
162164

163165
@classmethod

ibis/backends/dask/helpers.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def concat(cls, dfs, **kwargs):
3030

3131
@classmethod
3232
def asseries(cls, value, like=None):
33-
"""Ensure that value is a pandas Series object, broadcast if necessary."""
33+
"""Ensure that value is a dask Series object, broadcast if necessary."""
3434

3535
if isinstance(value, dd.Series):
3636
return value
@@ -50,7 +50,7 @@ def asseries(cls, value, like=None):
5050
elif isinstance(value, pd.Series):
5151
return dd.from_pandas(value, npartitions=1)
5252
elif like is not None:
53-
if isinstance(value, (tuple, list, dict)):
53+
if isinstance(value, (tuple, list, dict, np.ndarray)):
5454
fn = lambda df: pd.Series([value] * len(df), index=df.index)
5555
else:
5656
fn = lambda df: pd.Series(value, index=df.index)

ibis/backends/pandas/executor.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@ def visit(cls, op: ops.Node, **kwargs):
4949

5050
@classmethod
5151
def visit(cls, op: ops.Literal, value, dtype):
52-
if dtype.is_interval():
52+
if value is None:
53+
value = None
54+
elif dtype.is_interval():
5355
value = pd.Timedelta(value, dtype.unit.short)
5456
elif dtype.is_array():
5557
value = np.array(value)
@@ -220,7 +222,7 @@ def visit(cls, op: ops.FindInSet, needle, values):
220222
return pd.Series(result, name=op.name)
221223

222224
@classmethod
223-
def visit(cls, op: ops.Array, exprs):
225+
def visit(cls, op: ops.Array, exprs, dtype):
224226
return cls.rowwise(lambda row: np.array(row, dtype=object), exprs)
225227

226228
@classmethod

ibis/backends/polars/compiler.py

+16-11
Original file line numberDiff line numberDiff line change
@@ -87,25 +87,27 @@ def literal(op, **_):
8787
value = op.value
8888
dtype = op.dtype
8989

90-
if dtype.is_array():
91-
value = pl.Series("", value)
92-
typ = PolarsType.from_ibis(dtype)
93-
val = pl.lit(value, dtype=typ)
94-
return val.implode()
90+
# There are some interval types that _make_duration() can handle,
91+
# but PolarsType.from_ibis can't, so we need to handle them here.
92+
if dtype.is_interval():
93+
return _make_duration(value, dtype)
94+
95+
typ = PolarsType.from_ibis(dtype)
96+
if value is None:
97+
return pl.lit(None, dtype=typ)
98+
elif dtype.is_array():
99+
return pl.lit(pl.Series("", value).implode(), dtype=typ)
95100
elif dtype.is_struct():
96101
values = [
97102
pl.lit(v, dtype=PolarsType.from_ibis(dtype[k])).alias(k)
98103
for k, v in value.items()
99104
]
100105
return pl.struct(values)
101-
elif dtype.is_interval():
102-
return _make_duration(value, dtype)
103106
elif dtype.is_null():
104107
return pl.lit(value)
105108
elif dtype.is_binary():
106109
return pl.lit(value)
107110
else:
108-
typ = PolarsType.from_ibis(dtype)
109111
return pl.lit(op.value, dtype=typ)
110112

111113

@@ -974,9 +976,12 @@ def array_concat(op, **kw):
974976

975977

976978
@translate.register(ops.Array)
977-
def array_column(op, **kw):
978-
cols = [translate(col, **kw) for col in op.exprs]
979-
return pl.concat_list(cols)
979+
def array_literal(op, **kw):
980+
pdt = PolarsType.from_ibis(op.dtype)
981+
if op.exprs:
982+
return pl.concat_list([translate(col, **kw) for col in op.exprs]).cast(pdt)
983+
else:
984+
return pl.lit([], dtype=pdt)
980985

981986

982987
@translate.register(ops.ArrayCollect)

ibis/backends/sql/compiler.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1019,8 +1019,8 @@ def visit_InSubquery(self, op, *, rel, needle):
10191019
query = sg.select(STAR).from_(query)
10201020
return needle.isin(query=query)
10211021

1022-
def visit_Array(self, op, *, exprs):
1023-
return self.f.array(*exprs)
1022+
def visit_Array(self, op, *, exprs, dtype):
1023+
return self.cast(self.f.array(*exprs), dtype)
10241024

10251025
def visit_StructColumn(self, op, *, names, values):
10261026
return sge.Struct.from_arg_list(

ibis/backends/tests/snapshots/test_sql/test_union_aliasing/clickhouse/out.sql

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ WITH "t5" AS (
2626
SELECT
2727
"t0"."field_of_study",
2828
arrayJoin(
29-
[
29+
CAST([
3030
CAST(tuple('1970-71', "t0"."1970-71") AS Tuple("years" Nullable(String), "degrees" Nullable(Int64))),
3131
CAST(tuple('1975-76', "t0"."1975-76") AS Tuple("years" Nullable(String), "degrees" Nullable(Int64))),
3232
CAST(tuple('1980-81', "t0"."1980-81") AS Tuple("years" Nullable(String), "degrees" Nullable(Int64))),
@@ -45,7 +45,7 @@ WITH "t5" AS (
4545
CAST(tuple('2017-18', "t0"."2017-18") AS Tuple("years" Nullable(String), "degrees" Nullable(Int64))),
4646
CAST(tuple('2018-19', "t0"."2018-19") AS Tuple("years" Nullable(String), "degrees" Nullable(Int64))),
4747
CAST(tuple('2019-20', "t0"."2019-20") AS Tuple("years" Nullable(String), "degrees" Nullable(Int64)))
48-
]
48+
] AS Array(Tuple("years" Nullable(String), "degrees" Nullable(Int64))))
4949
) AS "__pivoted__"
5050
FROM "humanities" AS "t0"
5151
) AS "t1"

ibis/backends/tests/snapshots/test_sql/test_union_aliasing/duckdb/out.sql

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ WITH "t5" AS (
2626
SELECT
2727
"t0"."field_of_study",
2828
UNNEST(
29-
[
29+
CAST([
3030
{'years': '1970-71', 'degrees': "t0"."1970-71"},
3131
{'years': '1975-76', 'degrees': "t0"."1975-76"},
3232
{'years': '1980-81', 'degrees': "t0"."1980-81"},
@@ -45,7 +45,7 @@ WITH "t5" AS (
4545
{'years': '2017-18', 'degrees': "t0"."2017-18"},
4646
{'years': '2018-19', 'degrees': "t0"."2018-19"},
4747
{'years': '2019-20', 'degrees': "t0"."2019-20"}
48-
]
48+
] AS STRUCT("years" TEXT, "degrees" BIGINT)[])
4949
) AS "__pivoted__"
5050
FROM "humanities" AS "t0"
5151
) AS "t1"

ibis/backends/tests/snapshots/test_sql/test_union_aliasing/postgres/out.sql

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ WITH "t5" AS (
2626
SELECT
2727
"t0"."field_of_study",
2828
UNNEST(
29-
ARRAY[ROW(CAST('1970-71' AS VARCHAR), CAST("t0"."1970-71" AS BIGINT)), ROW(CAST('1975-76' AS VARCHAR), CAST("t0"."1975-76" AS BIGINT)), ROW(CAST('1980-81' AS VARCHAR), CAST("t0"."1980-81" AS BIGINT)), ROW(CAST('1985-86' AS VARCHAR), CAST("t0"."1985-86" AS BIGINT)), ROW(CAST('1990-91' AS VARCHAR), CAST("t0"."1990-91" AS BIGINT)), ROW(CAST('1995-96' AS VARCHAR), CAST("t0"."1995-96" AS BIGINT)), ROW(CAST('2000-01' AS VARCHAR), CAST("t0"."2000-01" AS BIGINT)), ROW(CAST('2005-06' AS VARCHAR), CAST("t0"."2005-06" AS BIGINT)), ROW(CAST('2010-11' AS VARCHAR), CAST("t0"."2010-11" AS BIGINT)), ROW(CAST('2011-12' AS VARCHAR), CAST("t0"."2011-12" AS BIGINT)), ROW(CAST('2012-13' AS VARCHAR), CAST("t0"."2012-13" AS BIGINT)), ROW(CAST('2013-14' AS VARCHAR), CAST("t0"."2013-14" AS BIGINT)), ROW(CAST('2014-15' AS VARCHAR), CAST("t0"."2014-15" AS BIGINT)), ROW(CAST('2015-16' AS VARCHAR), CAST("t0"."2015-16" AS BIGINT)), ROW(CAST('2016-17' AS VARCHAR), CAST("t0"."2016-17" AS BIGINT)), ROW(CAST('2017-18' AS VARCHAR), CAST("t0"."2017-18" AS BIGINT)), ROW(CAST('2018-19' AS VARCHAR), CAST("t0"."2018-19" AS BIGINT)), ROW(CAST('2019-20' AS VARCHAR), CAST("t0"."2019-20" AS BIGINT))]
29+
CAST(ARRAY[ROW(CAST('1970-71' AS VARCHAR), CAST("t0"."1970-71" AS BIGINT)), ROW(CAST('1975-76' AS VARCHAR), CAST("t0"."1975-76" AS BIGINT)), ROW(CAST('1980-81' AS VARCHAR), CAST("t0"."1980-81" AS BIGINT)), ROW(CAST('1985-86' AS VARCHAR), CAST("t0"."1985-86" AS BIGINT)), ROW(CAST('1990-91' AS VARCHAR), CAST("t0"."1990-91" AS BIGINT)), ROW(CAST('1995-96' AS VARCHAR), CAST("t0"."1995-96" AS BIGINT)), ROW(CAST('2000-01' AS VARCHAR), CAST("t0"."2000-01" AS BIGINT)), ROW(CAST('2005-06' AS VARCHAR), CAST("t0"."2005-06" AS BIGINT)), ROW(CAST('2010-11' AS VARCHAR), CAST("t0"."2010-11" AS BIGINT)), ROW(CAST('2011-12' AS VARCHAR), CAST("t0"."2011-12" AS BIGINT)), ROW(CAST('2012-13' AS VARCHAR), CAST("t0"."2012-13" AS BIGINT)), ROW(CAST('2013-14' AS VARCHAR), CAST("t0"."2013-14" AS BIGINT)), ROW(CAST('2014-15' AS VARCHAR), CAST("t0"."2014-15" AS BIGINT)), ROW(CAST('2015-16' AS VARCHAR), CAST("t0"."2015-16" AS BIGINT)), ROW(CAST('2016-17' AS VARCHAR), CAST("t0"."2016-17" AS BIGINT)), ROW(CAST('2017-18' AS VARCHAR), CAST("t0"."2017-18" AS BIGINT)), ROW(CAST('2018-19' AS VARCHAR), CAST("t0"."2018-19" AS BIGINT)), ROW(CAST('2019-20' AS VARCHAR), CAST("t0"."2019-20" AS BIGINT))] AS STRUCT<"years" VARCHAR, "degrees" BIGINT>[])
3030
) AS "__pivoted__"
3131
FROM "humanities" AS "t0"
3232
) AS "t1"

0 commit comments

Comments
 (0)