diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py
index 7a2f2dcfe..620f13682 100644
--- a/dask_expr/_expr.py
+++ b/dask_expr/_expr.py
@@ -4,6 +4,7 @@
 import numbers
 import operator
 import os
+import weakref
 from collections import defaultdict
 from collections.abc import Generator, Mapping
 
@@ -34,6 +35,37 @@
 
 no_default = "__no_default__"
 
+# Live expression objects keyed on (class, normalized args, normalized kwargs).
+# Values are held weakly, so an entry disappears once its expression is
+# garbage collected.
+_object_cache = weakref.WeakValueDictionary()
+
+
+def normalize_arg(arg):
+    """Convert an ``Expr`` operand into a value usable in a cache key.
+
+    Lists, tuples and dicts are normalized recursively and tagged with their
+    type, so ``[1, 2]`` and ``(1, 2)`` yield different keys.  Numbers are
+    tagged with their type too: ``1 == 1.0 == True`` with equal hashes, so
+    untagged they would make ``df + 1`` and ``df + 1.0`` share one object.
+
+    pandas objects and NumPy arrays are keyed on identity only, so an
+    in-place mutation of such an operand is not detected.
+
+    Anything else is returned unchanged and may be unhashable; a dict with
+    unhashable content raises ``TypeError`` here.  Callers must handle both.
+    """
+    if isinstance(arg, (list, tuple)):
+        return (type(arg), tuple(map(normalize_arg, arg)))
+    if isinstance(arg, dict):
+        # A frozenset needs no ordering; sorting fails for mixed-type keys.
+        return (dict, frozenset((k, normalize_arg(v)) for k, v in arg.items()))
+    if isinstance(arg, (pd.core.base.PandasObject, np.ndarray)):
+        return (type(arg), id(arg))  # not quite safe
+    if isinstance(arg, numbers.Number):
+        return (type(arg), arg)
+    return arg
+
 
 class Expr:
     """Primary class for all Expressions
@@ -46,7 +78,35 @@ class Expr:
     _defaults = {}
     _is_length_preserving = False
 
-    def __init__(self, *args, **kwargs):
+    def __new__(cls, *args, **kwargs):
+        """Return the cached instance for these operands, creating it if needed.
+
+        Expressions built from the same class and equivalent operands are the
+        same object for as long as one of them is alive.  If no hashable key
+        can be derived from the operands, a fresh uncached instance is
+        returned instead.
+        """
+        key = None
+        try:
+            key = (
+                cls,
+                tuple(map(normalize_arg, args)),
+                tuple(sorted(toolz.valmap(normalize_arg, kwargs).items())),
+            )
+            return _object_cache[key]
+        except KeyError:
+            pass  # hashable key, but nothing cached for it yet
+        except Exception:
+            # Best effort: the key could not be built, hashed or compared.
+            key = None
+
+        obj = object.__new__(cls)
+        cls._init(obj, *args, **kwargs)
+        if key is not None:
+            _object_cache[key] = obj
+        return obj
+
+    def _init(self, *args, **kwargs):
         operands = list(args)
         for parameter in type(self)._parameters[len(operands) :]:
             try:
diff --git a/dask_expr/tests/test_collection.py b/dask_expr/tests/test_collection.py
index 13a76e849..e5f7b1d43 100644
--- a/dask_expr/tests/test_collection.py
+++ b/dask_expr/tests/test_collection.py
@@ -1212,3 +1212,14 @@ def test_shape(df, pdf):
 
 def test_size(df, pdf):
     assert_eq(df.size, pdf.size)
+
+
+def test_object_caching(df):
+    a = df + 1
+    b = df + 1
+    assert a._expr is b._expr
+    assert a._meta is b._meta
+    # equal-but-differently-typed operands must not share a cached object
+    c = df + 1.0
+    assert a._expr is not c._expr
+    del a, b, c