Skip to content

Commit f5128e5

Browse files
committed
lit lit lit
1 parent b606d6c commit f5128e5

File tree

4 files changed

+19
-7
lines changed

4 files changed

+19
-7
lines changed

daft/expressions/expressions.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,11 @@ def lit(value: object) -> Expression:
110110
assert isinstance(exponent, int)
111111
lit_value = _decimal_lit(sign == 1, digits, exponent)
112112
elif isinstance(value, Series):
113-
lit_value = _series_lit(value._series)
113+
series = Series.from_pylist([value]).cast(DataType.fixed_size_list(value.datatype(), len(value)))
114+
lit_value = _series_lit(series._series)
115+
elif isinstance(value, list):
116+
list_of_exprs = [Expression._to_expression(item) for item in value]
117+
lit_value = list_(*list_of_exprs)._expr
114118
else:
115119
lit_value = _lit(value)
116120
return Expression._from_pyexpr(lit_value)
@@ -1457,7 +1461,7 @@ def is_in(self, other: Any) -> Expression:
14571461
other = [Expression._to_expression(item) for item in other]
14581462
elif not isinstance(other, Expression):
14591463
series = item_to_series("items", other)
1460-
other = [Expression._to_expression(series)]
1464+
other = [Expression._from_pyexpr(_series_lit(series._series))]
14611465
else:
14621466
other = [other]
14631467

@@ -4339,7 +4343,7 @@ def count_matches(
43394343
patterns = [patterns]
43404344
if not isinstance(patterns, Expression):
43414345
series = item_to_series("items", patterns)
4342-
patterns = Expression._to_expression(series)
4346+
patterns = Expression._from_pyexpr(_series_lit(series._series))
43434347

43444348
whole_words_expr = Expression._to_expression(whole_words)._expr
43454349
case_sensitive_expr = Expression._to_expression(case_sensitive)._expr

tests/expressions/test_expressions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
(b"a", DataType.binary()),
2626
(True, DataType.bool()),
2727
(None, DataType.null()),
28-
(Series.from_pylist([1, 2, 3]), DataType.int64()),
28+
(Series.from_pylist([1, 2, 3]), DataType.fixed_size_list(DataType.int64(), 3)),
2929
(date(2023, 1, 1), DataType.date()),
3030
(time(1, 2, 3, 4), DataType.time(timeunit=TimeUnit.from_str("us"))),
3131
(datetime(2023, 1, 1), DataType.timestamp(timeunit=TimeUnit.from_str("us"))),
@@ -641,7 +641,7 @@ def test_duration_lit(input, expected) -> None:
641641
def test_repr_series_lit() -> None:
642642
s = lit(Series.from_pylist([1, 2, 3]))
643643
output = repr(s)
644-
assert output == "lit([1, 2, 3])"
644+
assert output == "lit([[1, 2, 3]])"
645645

646646

647647
def test_list_value_counts():

tests/expressions/test_null_safe_equals.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import pyarrow as pa
44
import pytest
55

6+
from daft.daft import series_lit
67
from daft.expressions import col, lit
8+
from daft.expressions.expressions import Expression
79
from daft.recordbatch import MicroPartition
810

911

@@ -356,7 +358,13 @@ def test_length_mismatch_all_types(type_name, left_data, right_data):
356358
right_table = MicroPartition.from_pydict({"value": right_data})
357359

358360
with pytest.raises(ValueError) as exc_info:
359-
result = left_table.eval_expression_list([col("value").eq_null_safe(right_table.get_column_by_name("value"))])
361+
result = left_table.eval_expression_list(
362+
[
363+
col("value").eq_null_safe(
364+
Expression._from_pyexpr(series_lit(right_table.get_column_by_name("value")._series))
365+
)
366+
]
367+
)
360368
# Force evaluation by accessing the result
361369
result.get_column_by_name("value").to_pylist()
362370

tests/expressions/typing/test_float.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def test_fill_nan(binary_data_fixture):
3636
lhs, rhs = binary_data_fixture
3737
assert_typing_resolve_vs_runtime_behavior(
3838
data=binary_data_fixture,
39-
expr=col(lhs.name()).float.fill_nan(rhs),
39+
expr=col(lhs.name()).float.fill_nan(col(rhs.name())),
4040
run_kernel=lambda: lhs.float.fill_nan(rhs),
4141
resolvable=(
4242
lhs.datatype() in (DataType.float32(), DataType.float64())

0 commit comments

Comments
 (0)