Skip to content

Expr as singleton #798

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 42 additions & 22 deletions dask_expr/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,39 @@ def _unpack_collections(o):
class Expr:
_parameters = []
_defaults = {}
_instances = weakref.WeakValueDictionary()

def __init__(self, *args, **kwargs):
def __new__(cls, *args, **kwargs):
operands = list(args)
for parameter in type(self)._parameters[len(operands) :]:
for parameter in cls._parameters[len(operands) :]:
try:
operands.append(kwargs.pop(parameter))
except KeyError:
operands.append(type(self)._defaults[parameter])
default = cls._defaults[parameter]
if callable(default):
default = default()
operands.append(default)
assert not kwargs, kwargs
operands = [_unpack_collections(o) for o in operands]
self.operands = operands
inst = object.__new__(cls)
inst.operands = [_unpack_collections(o) for o in operands]
_name = inst._name
if _name in Expr._instances:
return Expr._instances[_name]

Expr._instances[_name] = inst
return inst

def _tune_down(self):
return None

def _tune_up(self, parent):
return None

def _cull_down(self):
return None

def _cull_up(self, parent):
return None

def __str__(self):
s = ", ".join(
Expand Down Expand Up @@ -204,28 +226,26 @@ def rewrite(self, kind: str):
_continue = False

# Rewrite this node
if down_name in expr.__dir__():
out = getattr(expr, down_name)()
out = getattr(expr, down_name)()
if out is None:
out = expr
if not isinstance(out, Expr):
return out
if out._name != expr._name:
expr = out
continue

# Allow children to rewrite their parents
for child in expr.dependencies():
out = getattr(child, up_name)(expr)
if out is None:
out = expr
if not isinstance(out, Expr):
return out
if out._name != expr._name:
if out is not expr and out._name != expr._name:
expr = out
continue

# Allow children to rewrite their parents
for child in expr.dependencies():
if up_name in child.__dir__():
out = getattr(child, up_name)(expr)
if out is None:
out = expr
if not isinstance(out, Expr):
return out
if out is not expr and out._name != expr._name:
expr = out
_continue = True
break
_continue = True
break

if _continue:
continue
Expand Down
1 change: 1 addition & 0 deletions dask_expr/io/_delayed.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class _DelayedExpr(Expr):
# Wraps a Delayed object to make it an Expr for now. This is hacky and we should
# integrate this properly...
# TODO
_parameters = ["obj"]

def __init__(self, obj):
self.obj = obj
Expand Down
77 changes: 43 additions & 34 deletions dask_expr/io/parquet.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import contextlib
import functools
import itertools
import operator
import warnings
Expand All @@ -26,8 +25,9 @@
from dask.dataframe.io.parquet.utils import _split_user_options
from dask.dataframe.io.utils import _is_local_fs
from dask.delayed import delayed
from dask.utils import apply, natural_sort_key, typename
from dask.utils import apply, funcname, natural_sort_key, typename
from fsspec.utils import stringify_path
from toolz import identity

from dask_expr._expr import (
EQ,
Expand All @@ -47,26 +47,15 @@
determine_column_projection,
)
from dask_expr._reductions import Len
from dask_expr._util import _convert_to_list
from dask_expr._util import _convert_to_list, _tokenize_deterministic
from dask_expr.io import BlockwiseIO, PartitionsFiltered

NONE_LABEL = "__null_dask_index__"

_cached_dataset_info = {}
_CACHED_DATASET_SIZE = 10
_CACHED_PLAN_SIZE = 10
_cached_plan = {}


def _control_cached_dataset_info(key):
if (
len(_cached_dataset_info) > _CACHED_DATASET_SIZE
and key not in _cached_dataset_info
):
key_to_pop = list(_cached_dataset_info.keys())[0]
_cached_dataset_info.pop(key_to_pop)


def _control_cached_plan(key):
if len(_cached_plan) > _CACHED_PLAN_SIZE and key not in _cached_plan:
key_to_pop = list(_cached_plan.keys())[0]
Expand Down Expand Up @@ -121,7 +110,7 @@ def _lower(self):
class ToParquetData(Blockwise):
_parameters = ToParquet._parameters

@cached_property
@property
def io_func(self):
return ToParquetFunctionWrapper(
self.engine,
Expand Down Expand Up @@ -257,7 +246,6 @@ def to_parquet(

# Clear read_parquet caches in case we are
# also reading from the overwritten path
_cached_dataset_info.clear()
_cached_plan.clear()

# Always skip divisions checks if divisions are unknown
Expand Down Expand Up @@ -383,11 +371,6 @@ def to_parquet(
if compute:
out = out.compute(**compute_kwargs)

# Invalidate the filesystem listing cache for the output path after write.
# We do this before returning, even if `compute=False`. This helps ensure
# that reading files that were just written succeeds.
fs.invalidate_cache(path)

return out


Expand All @@ -413,6 +396,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO):
"kwargs",
"_partitions",
"_series",
"_dataset_info_cache",
]
_defaults = {
"columns": None,
Expand All @@ -432,6 +416,7 @@ class ReadParquet(PartitionsFiltered, BlockwiseIO):
"kwargs": None,
"_partitions": None,
"_series": False,
"_dataset_info_cache": list,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new concept here is that I am moving off a global cache. The dataset_info is always calculated whenever a user calls read_parquet(foo), so the user will always receive an accurate representation of the dataset at the time of the call.
This dataset_info is cached in this parameter. I am choosing a list as a container, but this could be anything. I could also just set the operand and mutate the expression in place.

The benefit of using a parameter for this cache is that the cache will naturally propagate to all derived instances, e.g. whenever we rewrite the expression using Expr.substitute_parameters. This allows us to maintain the cache during optimization, and it ties the lifetime of the cache to the lifetime of the expression ancestry, removing any need for us to ever invalidate the cache.

}
_pq_length_stats = None
_absorb_projections = True
Expand Down Expand Up @@ -474,7 +459,21 @@ def _simplify_up(self, parent, dependents):
return Literal(sum(_lengths))

@cached_property
def _name(self):
return (
funcname(type(self)).lower()
+ "-"
+ _tokenize_deterministic(self.checksum, *self.operands)
)
Comment on lines +467 to +472
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This checksum is part of the _name, allowing us to differentiate expressions that point to modified states of the dataset. It also allows us to reuse already cached "plans / divisions" if the dataset did not change, which is the most common case.


@property
def checksum(self):
return self._dataset_info["checksum"]

@property
def _dataset_info(self):
if rv := self.operand("_dataset_info_cache"):
return rv[0]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When the ReadParquet expression is first initialized during a read_parquet call, this cache is empty and we'll fetch the dataset_info essentially at construction time of the expression object.

Subsequent derived expressions, which inherit the cache, will just access it, making the __new__ call instantaneous.

# Process and split user options
(
dataset_options,
Expand Down Expand Up @@ -536,13 +535,20 @@ def _dataset_info(self):
**other_options,
},
)
dataset_token = tokenize(*args)
if dataset_token not in _cached_dataset_info:
_control_cached_dataset_info(dataset_token)
_cached_dataset_info[dataset_token] = self.engine._collect_dataset_info(
*args
)
dataset_info = _cached_dataset_info[dataset_token].copy()
dataset_info = self.engine._collect_dataset_info(*args)
checksum = []
files_for_checksum = []
if dataset_info["has_metadata_file"]:
files_for_checksum = [self.path + fs.sep + "_metadata"]
else:
files_for_checksum = dataset_info["ds"].files

for file in files_for_checksum:
# The checksum / file info is usually already cached by the fsspec
# FileSystem dir_cache since this info was already asked for in
# _collect_dataset_info
checksum.append(fs.checksum(file))
dataset_info["checksum"] = tokenize(checksum)
Comment on lines +556 to +561
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To deal with the cache consistency problem described in #800, I am calculating a checksum here. For s3 this falls back to using the ETag provided in the listdir response. This should not add any overhead since this information is already cached by fsspec.
We're either taking the checksum of the metadata file or of all files that we iterate over. At this point the listdir operation is already done, so the checksum identifies every dataset uniquely. Since this checksum is added to the dataset_info, this also guarantees that the cache for the plan is invalidated if the dataset changes.


# Infer meta, accounting for index and columns arguments.
meta = self.engine._create_dd_meta(dataset_info)
Expand All @@ -558,6 +564,7 @@ def _dataset_info(self):
dataset_info["all_columns"] = all_columns
dataset_info["calculate_divisions"] = self.calculate_divisions

self._dataset_info_cache.append(dataset_info)
return dataset_info

@property
Expand All @@ -571,10 +578,10 @@ def _meta(self):
return meta[columns]
return meta

@cached_property
@property
def _io_func(self):
if self._plan["empty"]:
return lambda x: x
return identity
dataset_info = self._dataset_info
return ParquetFunctionWrapper(
self.engine,
Expand Down Expand Up @@ -662,7 +669,7 @@ def _update_length_statistics(self):
stat["num-rows"] for stat in _collect_pq_statistics(self)
)

@functools.cached_property
@property
def _fusion_compression_factor(self):
if self.operand("columns") is None:
return 1
Expand Down Expand Up @@ -767,9 +774,11 @@ def _maybe_list(val):
return [val]

return [
_maybe_list(val.to_list_tuple())
if hasattr(val, "to_list_tuple")
else _maybe_list(val)
(
_maybe_list(val.to_list_tuple())
if hasattr(val, "to_list_tuple")
else _maybe_list(val)
)
for val in self
]

Expand Down
3 changes: 2 additions & 1 deletion dask_expr/tests/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import dask.array as da
import numpy as np
import pytest
from dask.dataframe._compat import PANDAS_GE_210
from dask.dataframe._compat import PANDAS_GE_210, PANDAS_GE_220
from dask.dataframe.utils import UNKNOWN_CATEGORIES
from dask.utils import M

Expand Down Expand Up @@ -1035,6 +1035,7 @@ def test_head_down(df):
assert not isinstance(optimized.expr, expr.Head)


@pytest.mark.skipif(not PANDAS_GE_220, reason="not implemented")
def test_case_when(pdf, df):
result = df.x.case_when([(df.x.eq(1), 1), (df.y == 10, 2.5)])
expected = pdf.x.case_when([(pdf.x.eq(1), 1), (pdf.y == 10, 2.5)])
Expand Down