Skip to content

Commit d768ea3

Browse files
gforsyth authored and cpcloud committed
feat: add ops.Contains -> singular_or_list translation
1 parent ddf2e21 commit d768ea3

File tree

4 files changed

+47
-2
lines changed

4 files changed

+47
-2
lines changed

ibis_substrait/compiler/decompile.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -756,6 +756,15 @@ def decompile_if_then(
756756
) -> ir.ValueExpr:
757757
return decompile(if_then, children, offsets, decompiler)
758758

759+
@staticmethod
def decompile_singular_or_list(
    singular_or_list: stalg.Expression.SingularOrList,
    children: Sequence[ir.TableExpr],
    offsets: Sequence[int],
    decompiler: SubstraitDecompiler,
) -> ir.ValueExpr:
    """Decompile a ``SingularOrList`` message.

    All four arguments are forwarded unchanged to ``decompile`` and its
    result is returned as-is.
    """
    decompiled = decompile(singular_or_list, children, offsets, decompiler)
    return decompiled
767+
759768

760769
@decompile.register
761770
def _decompile_expression(
@@ -852,6 +861,19 @@ def _decompile_expression_if_then(
852861
return base_case
853862

854863

864+
@decompile.register
def _decompile_expression_singular_or_list(
    msg: stalg.Expression.SingularOrList,
    children: Sequence[ir.TableExpr],
    field_offsets: Sequence[int],
    decompiler: SubstraitDecompiler,
) -> ir.ValueExpr:
    """Turn a ``SingularOrList`` message into an ``isin`` expression.

    ``msg.value`` is decompiled first, then each entry of ``msg.options``
    in order; the result is ``<value>.isin([<options>...])``.
    """
    needle = decompile(msg.value, children, field_offsets, decompiler)
    # Bind the method before decompiling the options so the order of
    # operations matches a direct ``needle.isin([...])`` call.
    is_in = needle.isin
    candidates = []
    for option in msg.options:
        candidates.append(decompile(option, children, field_offsets, decompiler))
    return is_in(candidates)
875+
876+
855877
class LiteralDecompiler:
856878
@staticmethod
857879
def decompile_boolean(value: bool) -> tuple[bool, dt.Boolean]:

ibis_substrait/compiler/mapping.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
"Any": "any",
55
"Between": "between",
66
"Cast": "cast",
7-
"Contains": "contains",
87
"Count": "count",
98
"CountDistinct": "countdistinct",
109
"Divide": "/", # wrong but using for duckdb compatibility right now
@@ -26,7 +25,6 @@
2625
"Substring": "substring",
2726
"Subtract": "-", # wrong but using for duckdb compatibility right now
2827
"Sum": "sum",
29-
"ValueList": "values",
3028
}
3129

3230
SUBSTRAIT_IBIS_OP_MAPPING = {v: k for k, v in IBIS_SUBSTRAIT_OP_MAPPING.items()}

ibis_substrait/compiler/translate.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -928,3 +928,18 @@ def _simple_case(
928928
_else = {"else": translate(op.default, compiler, **kwargs)}
929929

930930
return stalg.Expression(if_then=stalg.Expression.IfThen(ifs=_ifs, **_else))
931+
932+
933+
@translate.register(ops.Contains)
def _contains(
    op: ops.Contains,
    expr: ir.TableExpr,
    compiler: SubstraitCompiler,
    **kwargs: Any,
) -> stalg.Expression:
    """Translate ``ops.Contains`` into a ``SingularOrList`` expression.

    ``op.value`` becomes the message's ``value`` and every element of
    ``op.options`` is translated, in order, into its ``options``. ``expr``
    is accepted but not used by this translation.
    """
    needle = translate(op.value, compiler, **kwargs)
    haystack = [translate(option, compiler, **kwargs) for option in op.options]
    membership = stalg.Expression.SingularOrList(value=needle, options=haystack)
    return stalg.Expression(singular_or_list=membership)

ibis_substrait/tests/compiler/test_decompiler.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,3 +228,13 @@ def test_decompile_if_then(t, compiler):
228228
plan = compiler.compile(expr)
229229
(result,) = decompile(plan)
230230
assert result.equals(expr)
231+
232+
233+
def test_singular_or_list(compiler):
    """Round-trip an ``isin`` filter through compile and decompile."""
    table = ibis.table([("bork", dt.string)], name="t")
    membership = table.bork.isin(["ork", "bork"])
    expr = table.filter(membership)

    decompiled = decompile(compiler.compile(expr))
    # Exactly one expression is expected back from the plan.
    (result,) = decompiled
    assert result.equals(expr)

0 commit comments

Comments
 (0)