Skip to content

Commit 2fe7f26

Browse files
committed
feat: add support for ops.Extract<span>
Covers `ExtractYear`, `ExtractMonth`, and `ExtractDay`. Substrait also allows `SECONDS` as the specified argument but I'm not clear on exactly what they're looking for there. Epoch Seconds?
1 parent 308d33a commit 2fe7f26

File tree

4 files changed

+68
-3
lines changed

4 files changed

+68
-3
lines changed

ibis_substrait/compiler/decompile.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -570,7 +570,7 @@ def _decompile_expression_aggregate_function(
570570
) -> ir.ValueExpr:
571571
extension = decompiler.function_extensions[aggregate_function.function_reference]
572572
function_name = extension.name
573-
op_type = getattr(ops, SUBSTRAIT_IBIS_OP_MAPPING[function_name])
573+
op_type = SUBSTRAIT_IBIS_OP_MAPPING[function_name]
574574
args = [
575575
decompile(arg, children, field_offsets, decompiler)
576576
for arg in aggregate_function.args
@@ -774,6 +774,15 @@ def decompile_cast(
774774
) -> ir.ValueExpr:
775775
return decompile(cast, children, offsets, decompiler)
776776

777+
@staticmethod
def decompile_enum(
    enum: stalg.Expression.Enum,
    children: Sequence[ir.TableExpr],
    offsets: Sequence[int],
    decompiler: SubstraitDecompiler,
) -> str:
    """Decompile a Substrait enum argument.

    ``children``, ``offsets`` and ``decompiler`` are accepted but not
    forwarded: the ``decompile`` handler registered for
    ``stalg.Expression.Enum`` takes the message alone and returns its
    ``specified`` string, which is why this method is annotated as
    returning ``str`` rather than an ibis expression.
    """
    # NOTE(review): the extra parameters presumably exist so this method can
    # be called the same way as the sibling ``decompile_*`` methods -- confirm
    # against the caller.
    return decompile(enum)
785+
777786

778787
@decompile.register
779788
def _decompile_expression(
@@ -800,7 +809,7 @@ def _decompile_expression_scalar_function(
800809
) -> ir.ValueExpr:
801810
extension = decompiler.function_extensions[msg.function_reference]
802811
function_name = extension.name
803-
op_type = getattr(ops, SUBSTRAIT_IBIS_OP_MAPPING[function_name])
812+
op_type = SUBSTRAIT_IBIS_OP_MAPPING[function_name]
804813
args = [decompile(arg, children, field_offsets, decompiler) for arg in msg.args]
805814
expr = op_type(*args).to_expr()
806815
output_type = _decompile_type(msg.output_type)
@@ -923,6 +932,13 @@ def _decompile_expression_cast(
923932
return column.cast(_decompile_type(msg.type))
924933

925934

935+
@decompile.register
def _decompile_expression_enum(
    msg: stalg.Expression.Enum,
) -> str:
    """Return the value stored in the enum message's ``specified`` field."""
    specified = msg.specified
    return specified
940+
941+
926942
class LiteralDecompiler:
927943
@staticmethod
928944
def decompile_boolean(value: bool) -> tuple[bool, dt.Boolean]:

ibis_substrait/compiler/mapping.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import ibis.expr.operations as ops
2+
13
IBIS_SUBSTRAIT_OP_MAPPING = {
24
"Add": "add",
35
"And": "and",
@@ -7,6 +9,9 @@
79
"CountDistinct": "countdistinct",
810
"Divide": "divide",
911
"Equals": "equal",
12+
"ExtractYear": "extract",
13+
"ExtractMonth": "extract",
14+
"ExtractDay": "extract",
1015
"Greater": "gt",
1116
"GreaterEqual": "gte",
1217
"Less": "lt",
@@ -25,4 +30,10 @@
2530
"Sum": "sum",
2631
}
2732

28-
SUBSTRAIT_IBIS_OP_MAPPING = {v: k for k, v in IBIS_SUBSTRAIT_OP_MAPPING.items()}
33+
def _extract_op(table, span):
    """Instantiate the ``Extract<Span>`` op (e.g. ``ExtractYear``) on ``table``.

    ``span`` is capitalized before the lookup, so ``"YEAR"`` and ``"year"``
    both resolve to ``ops.ExtractYear``.
    """
    op_class = getattr(ops, f"Extract{span.capitalize()}")
    return op_class(table)


# Reverse lookup: Substrait function name -> ibis operation class.
SUBSTRAIT_IBIS_OP_MAPPING = {
    substrait_name: getattr(ops, ibis_name)
    for ibis_name, substrait_name in IBIS_SUBSTRAIT_OP_MAPPING.items()
}
# override when reversing many-to-one mappings
SUBSTRAIT_IBIS_OP_MAPPING["extract"] = _extract_op

ibis_substrait/compiler/translate.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -959,3 +959,25 @@ def _cast(
959959
type=translate(op.to), input=translate(op.arg, compiler, **kwargs)
960960
)
961961
)
962+
963+
964+
@translate.register(ops.ExtractDateField)
def _extractdatefield(
    op: ops.ExtractDateField,
    expr: ir.TableExpr,
    compiler: SubstraitCompiler,
    **kwargs: Any,
) -> stalg.Expression:
    """Translate a date-field extraction op into a Substrait scalar function.

    The scalar function's arguments are the translated expression arguments
    of ``op``, followed by an enum argument naming the span to extract
    (e.g. ``YEAR`` for ``ExtractYear``).
    """
    scalar_func = stalg.Expression.ScalarFunction(
        function_reference=compiler.function_id(expr),
        output_type=translate(expr.type()),
        args=[
            translate(arg, compiler, **kwargs)
            for arg in op.args
            if isinstance(arg, ir.Expr)
        ],
    )
    # e.g. "ExtractYear" -> "YEAR".
    # Slice the prefix off explicitly: str.lstrip("Extract") removes any
    # leading run of the characters E/x/t/r/a/c rather than the literal
    # prefix, so it would also eat the start of a span name beginning with
    # one of those letters (e.g. "ExtractEpochSeconds" -> "pochSeconds").
    prefix = "Extract"
    op_name = type(op).__name__
    if op_name.startswith(prefix):
        op_name = op_name[len(prefix):]
    span = op_name.upper()
    scalar_func.args.add(enum=stalg.Expression.Enum(specified=span))
    return stalg.Expression(scalar_function=scalar_func)

ibis_substrait/tests/compiler/test_decompiler.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,3 +265,19 @@ def test_searchedcase(compiler):
265265
plan = compiler.compile(expr)
266266
(result,) = decompile(plan)
267267
assert result.equals(expr)
268+
269+
270+
@pytest.mark.parametrize("span", ["year", "month", "day"])
def test_extract_date(compiler, span):
    """Round-trip a date-field extraction through compile and decompile."""
    table = ibis.table([("o_orderdate", dt.date)], name="t")
    extract_field = getattr(table.o_orderdate, span)
    expr = table[extract_field()]
    plan = compiler.compile(expr)
    (result,) = decompile(plan)
    assert result.equals(expr)

0 commit comments

Comments
 (0)