[Enh]: A richer Expr internal representation #2571

Open
@dangotbanned

Related

This cuts across many issues and previous discussions; these are the ones I've seen pretty clear links to:

Issues

Description

Following (#2483 (comment)) I've been sneakily incubating this idea that I feel now has enough legs to share 😏.

The big picture is replacing ExprMetadata with a rich, node-based representation - that encodes every step of an Expr.
This might seem like a huge leap to take - but we're in a unique position where a battle-tested solution already exists 😎

So far I've been mapping out what this could look like on (oh-nodes).
I've opened a PR (#2572), so anyone should feel free to comment on that or on this thread.

Notes

I wrote up a bit of a guiding mantra for what I'm hoping the end result should look like:

  • Each Expr method should be representable by a single node
    • But the node does not need to be unique to the method
  • A chain of Expr methods should form a plan of operations
  • We must be able to enforce rules on what plans are permitted:
    • Must be flexible to both eager/lazy and individual backends
    • Must be flexible to a given context (select, with_columns, filter, group_by)
  • Nodes are:
    • Immutable, but
      • Can be extended/re-written at both the Narwhals & Compliant levels
    • Introspectable, but
      • Store as little-as-needed for the common case
      • Provide properties/methods for computing the less frequent metadata
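
To make the "immutable, but re-writable" and "introspectable, but lean" points above a bit more concrete, here is a minimal sketch using frozen dataclasses. Every name in it (`Node`, `Column`, `Alias`, `Literal`) is hypothetical and chosen for illustration only; the real classes live in the PR and may look quite different. Only `is_scalar` mirrors a property that appears in the examples in this issue.

```python
from __future__ import annotations

from dataclasses import dataclass, replace


@dataclass(frozen=True)
class Node:
    """Hypothetical base class for illustration; not the narwhals ExprIR."""

    @property
    def is_scalar(self) -> bool:
        # Less frequent metadata is computed on demand, not stored per node.
        return False


@dataclass(frozen=True)
class Column(Node):
    name: str


@dataclass(frozen=True)
class Literal(Node):
    value: object

    @property
    def is_scalar(self) -> bool:
        return True


@dataclass(frozen=True)
class Alias(Node):
    expr: Node
    name: str

    @property
    def is_scalar(self) -> bool:
        # An alias only renames, so it inherits the answer from its input.
        return self.expr.is_scalar


plan = Alias(expr=Column("a"), name="b")

# "Re-writing" an immutable node means building a new one;
# the original plan is left untouched and the child node is shared.
renamed = replace(plan, name="c")
```

Assigning to `plan.name` directly raises an error because the dataclass is frozen, which is what makes sharing sub-plans between expressions safe in this sketch.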

Examples

As everything is represented by different classes - this one is a bit visual 😄

Column selections

from narwhals._plan import demo as nwd

>>> nwd.col("a"), nwd.col("b", "c"), nwd.nth(1), nwd.nth(3, 4, 5)
(Narwhals DummyExpr:
 col('a'),
 Narwhals DummyExpr:
 cols(['b', 'c']),
 Narwhals DummyExpr:
 nth(1),
 Narwhals DummyExpr:
 index_columns((3, 4, 5)))

Literals

We can discern the kind of literal that's wrapped:

import polars as pl

from narwhals._plan import demo as nwd
from narwhals._plan.dummy import DummySeries

series = DummySeries.from_native(pl.Series([1.1, 1.2]))
scalar = nwd.lit(5)
series = nwd.lit(series)
>>> scalar, series
(Narwhals DummyExpr:
 lit(int: 5),
 Narwhals DummyExpr:
 lit(Series))
>>> scalar._ir.is_scalar, series._ir.is_scalar
(True, False)

Funky

How about something more complex?

import narwhals as nw
from narwhals._plan import demo as nwd

>>> nwd.col("a").alias("b").cast(nw.Int8()).n_unique() + (nwd.col("c").count() * nwd.lit(10))
Narwhals DummyExpr:
[(col('a').alias('b').cast(Int8).n_unique()) + ([(col('c').count()) * (lit(int: 10))])]
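
For readers wondering how operators can end up as nodes, the following is a rough sketch of the general technique: operator overloads on a thin wrapper that build a binary node instead of computing anything. The classes `Col`, `Lit`, `Binary`, and `Expr` are assumptions made up for this example and are not the `narwhals._plan` implementation; the repr format simply imitates the output shown above.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class Col:
    name: str

    def __repr__(self) -> str:
        return f"col({self.name!r})"


@dataclass(frozen=True)
class Lit:
    value: int

    def __repr__(self) -> str:
        return f"lit({type(self.value).__name__}: {self.value!r})"


@dataclass(frozen=True)
class Binary:
    left: object
    op: str
    right: object

    def __repr__(self) -> str:
        return f"[({self.left!r}) {self.op} ({self.right!r})]"


class Expr:
    """Hypothetical wrapper: each operator call records a node, nothing runs."""

    def __init__(self, ir: object) -> None:
        self._ir = ir

    def __add__(self, other: Expr) -> Expr:
        return Expr(Binary(self._ir, "+", other._ir))

    def __mul__(self, other: Expr) -> Expr:
        return Expr(Binary(self._ir, "*", other._ir))

    def __repr__(self) -> str:
        return repr(self._ir)


expr = Expr(Col("a")) + Expr(Col("c")) * Expr(Lit(10))
print(expr)  # [(col('a')) + ([(col('c')) * (lit(int: 10))])]
```

Python's own operator precedence decides the nesting here, so the multiplication ends up as the right-hand child of the addition without any extra parsing.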

Order dependence

Here I'm trying to enforce the rules from (#2528 (comment)).

The idea would be that we allow the last two variants in lazy backends:

import narwhals as nw
from narwhals._plan import demo as nwd

orderable_1 = nwd.col("a").alias("d").first()
orderable_2 = nwd.col("b").cast(nw.String()).last()
orderable_3 = nwd.col("c").sort_by(nwd.col("e")).first()
orderable_4 = nwd.col("d").last().over(order_by=nwd.col("e", "f", "g"))

The outputs include suggestions that use the actual Expr:

orderable_1

>>> nwd.ensure_orderable_rules(orderable_1)
OrderDependentExprError: first() is order-dependent and requires an ordering operation for lazy backends.
Hint:
Instead of:
    col('a').alias('d').first()

If you want to aggregate to a single value, try:
    col('a').alias('d').sort_by(...).first()

Otherwise, try:
    col('a').alias('d').first().over(order_by=...)

orderable_2

>>> nwd.ensure_orderable_rules(orderable_2)
OrderDependentExprError: last() is order-dependent and requires an ordering operation for lazy backends.
Hint:
Instead of:
    col('b').cast(String).last()

If you want to aggregate to a single value, try:
    col('b').cast(String).sort_by(...).last()

Otherwise, try:
    col('b').cast(String).last().over(order_by=...)

orderable_3

>>> nwd.ensure_orderable_rules(orderable_3)[0]
Narwhals DummyExpr:
col('c').sort_by(by=(col('e'),), options=SortMultipleOptions(descending=[False], nulls_last=[False])).first()

orderable_4

>>> nwd.ensure_orderable_rules(orderable_4)[0]
Narwhals DummyExpr:
col('d').last().over(order_by=[cols(['e', 'f', 'g'])])
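
A rule like this can be enforced by walking the node chain. The sketch below is one possible shape for such a check under simplified assumptions: every node has at most one `expr` child, and the names `Column`, `SortBy`, `First`, `Over`, `OrderDependentError`, `has_sort`, and `ensure_orderable` are all invented for this example. It is not the logic behind `nwd.ensure_orderable_rules`.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class Column:
    name: str


@dataclass(frozen=True)
class SortBy:
    expr: object
    by: tuple


@dataclass(frozen=True)
class First:
    expr: object


@dataclass(frozen=True)
class Over:
    expr: object
    order_by: tuple = ()


class OrderDependentError(ValueError):
    """Hypothetical error type for this sketch."""


def has_sort(node: object) -> bool:
    """Return True if `node` or anything it wraps is a SortBy."""
    while node is not None:
        if isinstance(node, SortBy):
            return True
        node = getattr(node, "expr", None)
    return False


def ensure_orderable(node: object, *, ordered: bool = False) -> None:
    """Raise if an order-dependent node has no ordering applied to it."""
    if isinstance(node, Over):
        # over(order_by=...) supplies an ordering to everything it wraps.
        ordered = ordered or bool(node.order_by)
    elif isinstance(node, First) and not (ordered or has_sort(node.expr)):
        raise OrderDependentError(
            "first() is order-dependent and requires an ordering operation"
        )
    child = getattr(node, "expr", None)
    if child is not None:
        ensure_orderable(child, ordered=ordered)


# Permitted: sorted before aggregating, or ordered via the window.
ensure_orderable(First(SortBy(Column("c"), by=(Column("e"),))))
ensure_orderable(Over(First(Column("d")), order_by=(Column("e"),)))
```

Calling `ensure_orderable(First(Column("a")))` raises `OrderDependentError` in this sketch, mirroring the first two cases above.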

Is this just a fancy __repr__?

The repr is mostly a nice side effect of what is happening under the hood.
Take this example of a pretty complex expression:

import narwhals as nw
from narwhals._plan import demo as nwd

lhs = nwd.col("a").alias("b").cast(nw.Int8()).n_unique()
rhs = nwd.col("c").count() * nwd.lit(10)
result = (lhs + rhs).last().over(order_by=nwd.col("d", "e", "f"), descending=True)

Another option for introspection is via ExprIR.__str__

node = result._ir
>>> str(node)

This is the meat of the idea - as we've got something like what you'd see in ast

Raw output

"WindowExpr(expr=Last(expr=BinaryExpr(left=NUnique(expr=Cast(dtype=Int8, expr=Alias(expr=Column(name=a), name=b))), op=Add(), right=BinaryExpr(left=Count(expr=Column(name=c)), op=Multiply(), right=Literal(value=ScalarLiteral(dtype=Unknown, value=10))))), partition_by=(), order_by=((cols(['d', 'e', 'f']),), SortOptions(descending=True, nulls_last=False)), options=Over())"

Ruff'd output

WindowExpr(
    expr=Last(
        expr=BinaryExpr(
            left=NUnique(expr=Cast(dtype=Int8, expr=Alias(expr=Column(name=a), name=b))),
            op=Add(),
            right=BinaryExpr(
                left=Count(expr=Column(name=c)),
                op=Multiply(),
                right=Literal(value=ScalarLiteral(dtype=Unknown, value=10)),
            ),
        )
    ),
    partition_by=(),
    order_by=((cols(["d", "e", "f"]),), SortOptions(descending=True, nulls_last=False)),
    options=Over(),
)
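
For comparison with the `ast` analogy mentioned above, this is how the standard library represents a similar arithmetic expression, and how a visitor can introspect it. It uses only the stdlib `ast` module and is purely an analogy; it says nothing about how the narwhals nodes are built or traversed. The `NameCollector` class is a made-up helper for this example.

```python
import ast

# Parse an expression shaped like the lhs + rhs example: a + (c * 10).
tree = ast.parse("a + c * 10", mode="eval")

# Nested BinOp/Name/Constant nodes, printed as an indented tree.
print(ast.dump(tree.body, indent=4))


class NameCollector(ast.NodeVisitor):
    """Collect every variable name referenced in the tree."""

    def __init__(self) -> None:
        self.names: list[str] = []

    def visit_Name(self, node: ast.Name) -> None:
        self.names.append(node.id)


collector = NameCollector()
collector.visit(tree)
print(collector.names)  # ['a', 'c']
```

The same kind of traversal is what a node-based Expr representation would open up, for example finding the root columns of an expression without evaluating it.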

References

Lots of links to polars source
