Skip to content

[DNM/RFC] Track dependents without weakrefs #831

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
75 changes: 49 additions & 26 deletions dask_expr/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import functools
import os
import weakref
from collections import defaultdict
from collections.abc import Generator

Expand All @@ -11,34 +10,59 @@
import toolz
from dask.dataframe.core import is_dataframe_like, is_index_like, is_series_like
from dask.utils import funcname, import_required, is_arraylike
from toolz.dicttoolz import merge

from dask_expr._util import _BackendData, _tokenize_deterministic


def _unpack_collections(o):
def _unpack_expr(o):
if isinstance(o, Expr):
return o

if hasattr(o, "expr"):
return o.expr
return o, o._name
elif hasattr(o, "expr"):
return o.expr, o.expr._name
else:
return o
return o, None


class Expr:
_parameters = []
_defaults = {}

def __init__(self, *args, **kwargs):
self._dependencies = {}
operands = list(args)
for parameter in type(self)._parameters[len(operands) :]:
try:
operands.append(kwargs.pop(parameter))
except KeyError:
operands.append(type(self)._defaults[parameter])
assert not kwargs, kwargs
operands = [_unpack_collections(o) for o in operands]
self.operands = operands
parsed_operands = []
children = set()
_subgraphs = []
_subgraph_instances = []
_graph_instances = {}
for o in operands:
expr, name = _unpack_expr(o)
parsed_operands.append(expr)
if name is not None:
children.add(name)
_subgraphs.append(expr._graph)
_subgraph_instances.append(expr._graph_instances)
_graph_instances[name] = expr

self.operands = parsed_operands
name = self._name
# Graph instances is a mapping name -> Expr instance
# Graph itself is a mapping of dependencies mapping names to a set of names
self._graph_instances = merge(_graph_instances, *_subgraph_instances)
self._graph = merge(*_subgraphs)
self._graph[name] = children
# Probably a bad idea to have a self ref
self._graph_instances[name] = self

def __hash__(self):
raise TypeError("Don't!")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Funny" story: The above implementation maintains two dicts: one that maps names to their dependencies and one that maps names to the actual objects. Initially I tried to use the Expr objects themselves as keys, i.e. a mapping Expr -> set[Expr]. That triggered funny recursion errors and it took me a while to understand why...

Whenever there is a hash collision, the implementation of set falls back to use __eq__ to compare the new object with the one that is stored under a given hash, i.e. object.__eq__(old, new). However, this doesn't evaluate to a bool but defines another expression... 💥

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, defining __eq__ is pretty hazardous in Python. It would be reasonable, I think, to drop Expr.__eq__ and use explicit Eq(a, b) calls instead, and then rely on Frame.__eq__ for user comfort.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or rather, use whatever class is defined in collections.py to hold the __eq__ method


def __str__(self):
s = ", ".join(
Expand Down Expand Up @@ -131,6 +155,21 @@ def dependencies(self):
# Dependencies are `Expr` operands only
return [operand for operand in self.operands if isinstance(operand, Expr)]

@functools.cached_property
def _dependent_graph(self):
    """Reverse of ``self._graph``: map each name to the expressions using it.

    ``self._graph`` maps an expression name to the set of names it depends
    on, and ``self._graph_instances`` maps a name to its ``Expr`` instance.
    This property inverts the first mapping and resolves the dependent names
    to instances through the second.

    Returns
    -------
    defaultdict[name, list[Expr]]
        Every name that appears as a key of ``self._graph`` is present, with
        an empty list when nothing depends on it. Unknown names also yield an
        empty list.
    """
    # Invert on names only; names are hashable, so a set de-duplicates edges.
    # One pass over every (expr, dependency) edge.
    dependent_names = defaultdict(set)
    for name, dependencies in self._graph.items():
        # Touch the key so that expressions without dependents still show up.
        dependent_names[name]
        for dep in dependencies:
            dependent_names[dep].add(name)

    # Resolve names to instances. Collect them in lists rather than sets:
    # putting Expr objects into a set needs ``Expr.__hash__`` (which raises)
    # and, on hash collisions, ``Expr.__eq__`` (which builds a new expression
    # instead of returning a bool). Names are already unique, so a list loses
    # nothing.
    instances = self._graph_instances
    rv = defaultdict(list)
    for name, names in dependent_names.items():
        rv[name] = [instances[n] for n in names]
    return rv

def dependents(self):
    """Return the mapping from expression name to its dependent expressions.

    This is a plain accessor for the ``_dependent_graph`` attribute; the
    object it returns is that attribute itself, not a copy.
    """
    reverse_graph = self._dependent_graph
    return reverse_graph

def _task(self, index: int):
"""The task for the i'th partition

Expand Down Expand Up @@ -312,7 +351,7 @@ def simplify_once(self, dependents: defaultdict):
def simplify(self) -> Expr:
expr = self
while True:
dependents = collect_depdendents(expr)
dependents = expr.dependents()
new = expr.simplify_once(dependents=dependents)
if new._name == expr._name:
break
Expand Down Expand Up @@ -678,19 +717,3 @@ def find_operations(self, operation: type | tuple[type]) -> Generator[Expr]:
or issubclass(operation, Expr)
), "`operation` must be`Expr` subclass)"
return (expr for expr in self.walk() if isinstance(expr, operation))


def collect_depdendents(expr) -> defaultdict:
    """Walk the expression graph below ``expr`` and record who uses what.

    Performs an iterative depth-first traversal starting at ``expr``. Each
    distinct ``_name`` is expanded once; for every dependency of an expanded
    node, a weak reference to that node is appended under the dependency's
    name.

    Returns
    -------
    defaultdict[name, list[weakref.ref]]
        Maps a dependency's ``_name`` to weak references of the expressions
        that list it among their ``dependencies()``.
    """
    dependents = defaultdict(list)
    visited = set()
    pending = [expr]
    while pending:
        current = pending.pop()
        if current._name not in visited:
            visited.add(current._name)
            for child in current.dependencies():
                # Queue the child for its own expansion and register the edge.
                pending.append(child)
                dependents[child._name].append(weakref.ref(current))
    return dependents
4 changes: 2 additions & 2 deletions dask_expr/_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1186,7 +1186,7 @@ def _simplify_up(self, parent, dependents):
if self._name != parent.frame._name:
# We can't push the filter through the filter condition
return
parents = [x() for x in dependents[self._name] if x() is not None]
parents = dependents[self._name]
if not all(isinstance(p, Filter) for p in parents):
return
return type(self)(
Expand Down Expand Up @@ -3171,7 +3171,7 @@ def determine_column_projection(expr, parent, dependents, additional_columns=Non
column_union = []
else:
column_union = parent.columns.copy()
parents = [x() for x in dependents[expr._name] if x() is not None]
parents = dependents[expr._name]

for p in parents:
if len(p.columns) > 0:
Expand Down
5 changes: 1 addition & 4 deletions dask_expr/io/_delayed.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,7 @@ class _DelayedExpr(Expr):
# Wraps a Delayed object to make it an Expr for now. This is hacky and we should
# integrate this properly...
# TODO

def __init__(self, obj):
self.obj = obj
self.operands = [obj]
_parameters = ["obj"]

@property
def _name(self):
Expand Down