diff --git a/mesmerize_core/caiman_extensions/_utils.py b/mesmerize_core/caiman_extensions/_utils.py index f422ac8..7dfbe3e 100644 --- a/mesmerize_core/caiman_extensions/_utils.py +++ b/mesmerize_core/caiman_extensions/_utils.py @@ -1,7 +1,9 @@ -from functools import wraps -from typing import Union +from typing import Optional, Union, Protocol from uuid import UUID +import pandas as pd + +from mesmerize_core.utils import wrapsmethod from mesmerize_core.caiman_extensions._batch_exceptions import ( BatchItemNotRunError, BatchItemUnsuccessfulError, @@ -10,9 +12,14 @@ ) -def validate(algo: str = None): +class SeriesExtensions(Protocol): + """Common interface for series accessors to help with type hinting""" + _series: pd.Series + + +def validate(algo: Optional[str] = None): def dec(func): - @wraps(func) + @wrapsmethod(func) def wrapper(self, *args, **kwargs): if self._series["outputs"] is None: raise BatchItemNotRunError("Item has not been run") @@ -38,7 +45,7 @@ def wrapper(self, *args, **kwargs): def _verify_and_lock_batch_file(func): """Acquires lock and ensures batch file has the same items as current df before calling wrapped function""" - @wraps(func) + @wrapsmethod(func) def wrapper(instance, *args, **kwargs): with instance._batch_lock: disk_df = instance.reload_from_disk() @@ -53,7 +60,7 @@ def wrapper(instance, *args, **kwargs): def _index_parser(func): - @wraps(func) + @wrapsmethod(func) def _parser(instance, *args, **kwargs): if "index" in kwargs.keys(): index: Union[int, str, UUID] = kwargs["index"] diff --git a/mesmerize_core/caiman_extensions/cache.py b/mesmerize_core/caiman_extensions/cache.py index 62dbd44..47053ea 100644 --- a/mesmerize_core/caiman_extensions/cache.py +++ b/mesmerize_core/caiman_extensions/cache.py @@ -1,17 +1,22 @@ -from functools import wraps -from typing import Union, Optional +import inspect +from typing import Union, Optional, TypeVar import pandas as pd import time import numpy as np import sys from caiman.source_extraction.cnmf import CNMF -import re -from sys import getsizeof import copy +from ..utils import wrapsmethod +from ._utils import SeriesExtensions -def _check_arg_equality(args, cache_args): + +# return type of decorated method +R = TypeVar("R") + + +def _check_arg_equality(args, cache_args) -> bool: if not type(args) == type(cache_args): return False if isinstance(cache_args, np.ndarray): @@ -20,31 +25,54 @@ def _check_arg_equality(args, cache_args): return cache_args == args -def _check_args_equality(args, cache_args): +def _check_args_equality(args, cache_args) -> bool: if len(args) != len(cache_args): return False - equality = list() + if isinstance(args, tuple): for arg, cache_arg in zip(args, cache_args): - equality.append(_check_arg_equality(arg, cache_arg)) + if not _check_arg_equality(arg, cache_arg): + return False else: - for k in args.keys(): - equality.append(_check_arg_equality(args[k], cache_args[k])) - return all(equality) + for k, v in args.items(): + if k not in cache_args or not _check_arg_equality(v, cache_args[k]): + return False + return True -def _return_wrapper(output, copy_bool): +def _return_wrapper(output: R, copy_bool: bool) -> R: if copy_bool == True: return copy.deepcopy(output) else: return output +def _get_item_size(item) -> int: + """Recursively compute size of return value""" + if isinstance(item, np.ndarray): + return item.data.nbytes + + elif isinstance(item, (tuple, list)): + size = 0 + for entry in item: + size += _get_item_size(entry) + return size + + elif isinstance(item, CNMF): + size = 0 + for attr in item.estimates.__dict__.values(): + size += _get_item_size(attr) + return size + + else: + return sys.getsizeof(item) + + class Cache: def __init__(self, cache_size: Optional[Union[int, str]] = None): self.cache = pd.DataFrame( data=None, - columns=["uuid", "function", "args", "kwargs", "return_val", "time_stamp"], + columns=["uuid", "function", "kwargs", "return_val", "time_stamp", "added_time", "bytes"], ) self.set_maxsize(cache_size) @@ -52,10 +80,10 @@ def get_cache(self): return self.cache def clear_cache(self): - while len(self.cache.index) != 0: + while len(self.cache) != 0: self.cache.drop(index=self.cache.index[-1], axis=0, inplace=True) - def set_maxsize(self, max_size: Union[int, str]): + def set_maxsize(self, max_size: Optional[Union[int, str]]): if max_size is None: self.storage_type = "RAM" self.size = 1024**3 @@ -70,68 +98,83 @@ def set_maxsize(self, max_size: Union[int, str]): self.size = max_size def _get_cache_size_bytes(self): - """Returns in bytes""" - cache_size = 0 - for i in range(len(self.cache.index)): - if isinstance(self.cache.iloc[i, 4], np.ndarray): - cache_size += self.cache.iloc[i, 4].data.nbytes - elif isinstance(self.cache.iloc[i, 4], (tuple, list)): - for lists in self.cache.iloc[i, 4]: - for array in lists: - cache_size += array.data.nbytes - elif isinstance(self.cache.iloc[i, 4], CNMF): - sizes = list() - for attr in self.cache.iloc[i, 4].estimates.__dict__.values(): - if isinstance(attr, np.ndarray): - sizes.append(attr.data.nbytes) - else: - sizes.append(getsizeof(attr)) - else: - cache_size += sys.getsizeof(self.cache.iloc[i, 4]) - - return cache_size + return self.cache.loc[:, "bytes"].sum() def use_cache(self, func): - @wraps(func) - def _use_cache(instance, *args, **kwargs): - if "return_copy" in kwargs.keys(): - return_copy = kwargs["return_copy"] - else: - return_copy = True - - if self.size == 0: + """ + Caching decorator. + + Usage: + + .. code-block:: python + @cache.use_cache + def my_costly_method(self, *, return_copy=True): + ... + + return_copy determines whether an entry that is found in the cache is copied before it is returned. + The decorated function *must* take return_copy as a keyword-only paramter, and this will be read by the decorator. + """ + # get default value of return_copy from function signature + params = inspect.signature(func).parameters + return_copy_arg = params.get("return_copy") + if return_copy_arg is None: + raise TypeError("return_copy must be in wrapped function signature when not provided to decorator") + elif return_copy_arg.kind != inspect.Parameter.KEYWORD_ONLY: + raise TypeError("return_copy must be a keyword-only argument") + elif return_copy_arg.default == inspect.Parameter.empty: + return_copy_default = None # unlikely but in this case return_copy would be required + else: + return_copy_default = return_copy_arg.default + assert isinstance(return_copy_default, bool), "return_copy default should be bool" + + @wrapsmethod(func) + def _use_cache_wrapper(instance: SeriesExtensions, *args, **kwargs): + # extract return_copy; return_copy is keyword only, so only have to look in kwargs + return_copy = kwargs.get("return_copy", return_copy_default) + if return_copy is None: # no default case + raise TypeError("Must provide a value for return_copy") + + if not isinstance(return_copy, bool): + raise TypeError("return_copy must be a bool") + + # if we are not storing anything in the cache, just do the function call, no need to search + # still make a copy if copy_bool to make absolutely sure it's not aliasing another object + if self.size == 0: self.clear_cache() - return _return_wrapper(func(instance, *args, **kwargs), return_copy) - - # if cache is empty, will always be a cache miss - if len(self.cache.index) == 0: - return_val = func(instance, *args, **kwargs) - self.cache.loc[len(self.cache.index)] = [ - instance._series["uuid"], - func.__name__, - args, - kwargs, - return_val, - time.time(), - ] - return _return_wrapper(return_val, copy_bool=return_copy) - + return _return_wrapper(func(instance, *args, **kwargs), copy_bool=return_copy) + + # iterate through signature and make dict containing arguments to compare, including defaults + args_dict = {} + for i, (param_name, param) in enumerate(params.items()): + if i == 0 or param_name == "return_copy": + continue # skip self/instance and return_copy + elif i-1 < len(args): + args_dict[param_name] = args[i-1] + elif param_name in kwargs: + args_dict[param_name] = kwargs[param_name] + else: + assert param.default != inspect.Parameter.empty, "must have a default argument or there would be a TypeError" + args_dict[param_name] = param.default + # checking to see if there is a cache hit - for i in range(len(self.cache.index)): + for ind, row in self.cache.iterrows(): if ( - self.cache.iloc[i, 0] == instance._series["uuid"] - and self.cache.iloc[i, 1] == func.__name__ - and _check_args_equality(args, self.cache.iloc[i, 2]) - and _check_args_equality(kwargs, self.cache.iloc[i, 3]) + row.at["uuid"] == instance._series["uuid"] + and row.at["function"] == func.__name__ + and _check_args_equality(args_dict, row.at["kwargs"]) ): - self.cache.iloc[i, 5] = time.time() - return_val = self.cache.iloc[i, 4] - return _return_wrapper(self.cache.iloc[i, 4], copy_bool=return_copy) + self.cache.at[ind, "time_stamp"] = time.time() # not supposed to modify row from iterrows + return _return_wrapper(row.at["return_val"], copy_bool=return_copy) # no cache hit, must check cache limit, and if limit is going to be exceeded...remove least recently used and add new entry # if memory type is 'ITEMS': drop the least recently used and then add new item - if self.storage_type == "ITEMS" and len(self.cache.index) >= self.size: - return_val = func(instance, *args, **kwargs) + return_val = func(instance, *args, **kwargs) + curr_val_size = _get_item_size(return_val) + if self.storage_type == "RAM" and curr_val_size > self.size: + # too big to fit in the cache, and no point in evicting other items, so just return + return _return_wrapper(return_val, copy_bool=return_copy) + + if self.storage_type == "ITEMS" and len(self.cache) >= self.size: self.cache.drop( index=self.cache.sort_values( by=["time_stamp"], ascending=False @@ -139,21 +182,11 @@ def _use_cache(instance, *args, **kwargs): axis=0, inplace=True, ) - self.cache = self.cache.reset_index(drop=True) - self.cache.loc[len(self.cache.index)] = [ - instance._series["uuid"], - func.__name__, - args, - kwargs, - return_val, - time.time(), - ] - return _return_wrapper( - self.cache.iloc[len(self.cache.index) - 1, 4], copy_bool=return_copy - ) + self.cache.reset_index(drop=True, inplace=True) + # if memory type is 'RAM': add new item and then remove least recently used items until cache is under correct size again elif self.storage_type == "RAM": - while self._get_cache_size_bytes() > self.size: + while len(self.cache) > 1 and self._get_cache_size_bytes() + curr_val_size > self.size: # can't do anything if it's empty self.cache.drop( index=self.cache.sort_values( by=["time_stamp"], ascending=False @@ -161,31 +194,23 @@ def _use_cache(instance, *args, **kwargs): axis=0, inplace=True, ) - self.cache = self.cache.reset_index(drop=True) - return_val = func(instance, *args, **kwargs) - self.cache.loc[len(self.cache.index)] = [ - instance._series["uuid"], - func.__name__, - args, - kwargs, - return_val, - time.time(), - ] - # no matter the storage type if size is not going to be exceeded for either, then item can just be added to cache - else: - return_val = func(instance, *args, **kwargs) - self.cache.loc[len(self.cache.index)] = [ - instance._series["uuid"], - func.__name__, - args, - kwargs, - return_val, - time.time(), - ] - + self.cache.reset_index(drop=True, inplace=True) + + # now ready to add to cache + add_time = time.time() + self.cache.loc[len(self.cache)] = [ + instance._series["uuid"], + func.__name__, + args_dict, + return_val, + add_time, + add_time, + curr_val_size, + ] return _return_wrapper(return_val, copy_bool=return_copy) - return _use_cache + return _use_cache_wrapper + def invalidate(self, pre: bool = True, post: bool = True): """ @@ -202,7 +227,7 @@ def invalidate(self, pre: bool = True, post: bool = True): """ def _invalidate(func): - @wraps(func) + @wrapsmethod(func) def __invalidate(instance, *args, **kwargs): u = instance._series["uuid"] diff --git a/mesmerize_core/caiman_extensions/cnmf.py b/mesmerize_core/caiman_extensions/cnmf.py index 533e3c8..9811f4b 100644 --- a/mesmerize_core/caiman_extensions/cnmf.py +++ b/mesmerize_core/caiman_extensions/cnmf.py @@ -6,7 +6,6 @@ from caiman.source_extraction.cnmf import CNMF from caiman.source_extraction.cnmf.cnmf import load_CNMF from caiman.utils.visualization import get_contours as caiman_get_contours -from functools import wraps import os from copy import deepcopy @@ -14,64 +13,50 @@ from .cache import Cache from ..arrays import * from ..arrays._base import LazyArray +from ..utils import wrapsmethod cnmf_cache = Cache() +ComponentString = Literal["all", "good", "bad"] +ComponentSpecifier = Union[np.ndarray, ComponentString, None] # this decorator MUST be called BEFORE caching decorators! def _component_indices_parser(func): - @wraps(func) - def _parser(instance, *args, **kwargs) -> Any: - if "component_indices" in kwargs.keys(): - component_indices: Union[np.ndarray, str, None] = kwargs[ - "component_indices" - ] - elif len(args) > 0: - component_indices = args[0] # always first positional arg in the extensions - else: - component_indices = None # default - - cnmf_obj = instance.get_output() - - # TODO: finally time to learn Python's new switch case - accepted = (np.ndarray, str, type(None)) - if not isinstance(component_indices, accepted): - raise TypeError(f"`component_indices` must be one of type: {accepted}") - + @wrapsmethod(func) + def _parser(instance: 'CNMFExtensions', component_indices: ComponentSpecifier = None, *args, **kwargs): if isinstance(component_indices, np.ndarray): - pass + return func(instance, component_indices, *args, **kwargs) - elif component_indices is None: - component_indices = np.arange(cnmf_obj.estimates.A.shape[1]) + if component_indices is None: + component_indices = "all" if isinstance(component_indices, str): - accepted = ["all", "good", "bad"] - if component_indices not in accepted: - raise ValueError( - f"Accepted `str` values for `component_indices` are: {accepted}" - ) - if component_indices == "all": - component_indices = np.arange(cnmf_obj.estimates.A.shape[1]) + component_indices = np.arange(instance.get_n_components()) elif component_indices == "good": - component_indices = cnmf_obj.estimates.idx_components + component_indices = np.asarray(instance.get_good_components()) elif component_indices == "bad": - component_indices = cnmf_obj.estimates.idx_components_bad - if "component_indices" in kwargs.keys(): - kwargs["component_indices"] = component_indices + component_indices = np.asarray(instance.get_bad_components()) + + else: + raise ValueError( + f"Accepted `str` values for `component_indices` are: {ComponentString.__args__}" + ) else: - args = (component_indices, *args[1:]) + raise TypeError( + "`component_indices` must be one of these types/values: np.ndarray, None, " + + ", ".join(repr(a) for a in ComponentString.__args__)) - return func(instance, *args, **kwargs) + return func(instance, component_indices, *args, **kwargs) return _parser def _check_permissions(func): - @wraps(func) + @wrapsmethod(func) def __check(instance, *args, **kwargs): cnmf_obj_path = instance.get_output_path() @@ -95,7 +80,7 @@ def __init__(self, s: pd.Series): self._series = s @validate("cnmf") - def get_cnmf_memmap(self, mode: str = "r") -> np.ndarray: + def get_cnmf_memmap(self, mode="r") -> np.ndarray: """ Get the CNMF C-order memmap. This should NOT be used for viewing the movie frames use ``caiman.get_input_movie()`` for that purpose. @@ -152,7 +137,7 @@ def get_output_path(self) -> Path: @validate("cnmf") @cnmf_cache.use_cache - def get_output(self, return_copy=True) -> CNMF: + def get_output(self, *, return_copy=True) -> CNMF: """ Parameters ---------- @@ -188,17 +173,12 @@ def get_output(self, return_copy=True) -> CNMF: # Need to create a cache object that takes the item's UUID and returns based on that # collective global cache - return load_CNMF(self.get_output_path()) + return load_CNMF(str(self.get_output_path())) @validate("cnmf") @_component_indices_parser @cnmf_cache.use_cache - def get_masks( - self, - component_indices: Union[np.ndarray, str] = None, - threshold: float = 0.01, - return_copy=True, - ) -> np.ndarray: + def get_masks(self, component_indices: ComponentSpecifier = None, threshold=0.01, *, return_copy=True) -> np.ndarray: """ | Get binary masks of the spatial components at the given ``component_indices``. | Created from ``CNMF.estimates.A`` @@ -226,6 +206,7 @@ def get_masks( shape is [dim_0, dim_1, n_components] """ + component_indices = cast(np.ndarray, component_indices) # guaranteed by index parsing decorator cnmf_obj = self.get_output() @@ -267,12 +248,7 @@ def _get_spatial_contours(cnmf_obj: CNMF, component_indices, swap_dim): @validate("cnmf") @_component_indices_parser @cnmf_cache.use_cache - def get_contours( - self, - component_indices: Union[np.ndarray, str] = None, - swap_dim: bool = True, - return_copy=True, - ) -> Tuple[List[np.ndarray], List[np.ndarray]]: + def get_contours(self, component_indices: ComponentSpecifier = None, swap_dim=True, *, return_copy=True) -> tuple[list[np.ndarray], list[np.ndarray]]: """ Get the contour and center of mass for each spatial footprint Note, the centers of mass are different from those computed by CaImAn. @@ -297,12 +273,13 @@ def get_contours( Returns ------- - Tuple[List[np.ndarray], List[np.ndarray]] + tuple[list[np.ndarray], list[np.ndarray]] | (List[coordinates array], List[centers of masses array]) | each array of coordinates is 2D, [xs, ys] | each center of mass is [x, y] """ + component_indices = cast(np.ndarray, component_indices) # guaranteed by index parsing decorator cnmf_obj = self.get_output() contours = self._get_spatial_contours(cnmf_obj, component_indices, swap_dim) @@ -322,13 +299,7 @@ def get_contours( @validate("cnmf") @_component_indices_parser @cnmf_cache.use_cache - def get_temporal( - self, - component_indices: Union[np.ndarray, str] = None, - add_background: bool = False, - add_residuals: bool = False, - return_copy=True, - ) -> np.ndarray: + def get_temporal(self, component_indices: ComponentSpecifier = None, add_background=False, add_residuals=False, *, return_copy=True) -> np.ndarray: """ Get the temporal components for this CNMF item, basically ``CNMF.estimates.C`` @@ -378,17 +349,16 @@ def get_temporal( plot.show() """ + component_indices = cast(np.ndarray, component_indices) # guaranteed by index parsing decorator cnmf_obj = self.get_output() - C = cnmf_obj.estimates.C[component_indices] - f = cnmf_obj.estimates.f - - temporal = C + temporal = cnmf_obj.estimates.C[component_indices] if add_background: - temporal += f - elif add_residuals: + temporal += cnmf_obj.estimates.f + + if add_residuals: temporal += cnmf_obj.estimates.YrA[component_indices] return temporal @@ -396,12 +366,7 @@ def get_temporal( @validate("cnmf") @_component_indices_parser @cnmf_cache.use_cache - def get_rcm( - self, - component_indices: Union[np.ndarray, str] = None, - temporal_components: np.ndarray = None, - return_copy=False, - ) -> LazyArrayRCM: + def get_rcm(self, component_indices: ComponentSpecifier = None, temporal_components: Optional[np.ndarray] = None, *, return_copy=False) -> LazyArrayRCM: """ Return the reconstructed movie with no background, i.e. ``A ⊗ C``, as a ``LazyArray``. This is an array that performs lazy computation of the reconstructed movie only upon indexing. @@ -457,6 +422,7 @@ def get_rcm( iw = ImageWidget(data=rcm) iw.show() """ + component_indices = cast(np.ndarray, component_indices) # guaranteed by index parsing decorator cnmf_obj = self.get_output() @@ -487,12 +453,17 @@ def get_rcm( @validate("cnmf") @cnmf_cache.use_cache - def get_rcb( - self, - ) -> LazyArrayRCB: + def get_rcb(self, *, return_copy=False) -> LazyArrayRCB: """ Return the reconstructed background, ``(b ⊗ f)`` + Parameters + ---------- + return_copy: bool, default ``False`` + | if ``True`` returns a copy of the cached value in memory. + | if ``False`` returns the same object as the cached value in memory + | ``False`` is used by default when returning ``LazyArrays`` for technical reasons + Returns ------- LazyArrayRCB @@ -543,10 +514,17 @@ def get_rcb( @validate("cnmf") @cnmf_cache.use_cache - def get_residuals(self) -> LazyArrayResiduals: + def get_residuals(self, *, return_copy=False) -> LazyArrayResiduals: """ Return residuals, ``Y - (A ⊗ C) - (b ⊗ f)`` + Parameters + ---------- + return_copy: bool, default ``False`` + | if ``True`` returns a copy of the cached value in memory. + | if ``False`` returns the same object as the cached value in memory + | ``False`` is used by default when returning ``LazyArrays`` for technical reasons + Returns ------- LazyArrayResiduals @@ -658,7 +636,7 @@ def run_detrend_dfof( @_component_indices_parser @cnmf_cache.use_cache def get_detrend_dfof( - self, component_indices: Union[np.ndarray, str] = None, return_copy: bool = True + self, component_indices: ComponentSpecifier = None, *, return_copy: bool = True ): """ Get the detrended dF/F0 curves after calling ``run_detrend_dfof``. @@ -684,6 +662,7 @@ def get_detrend_dfof( shape is [n_components, n_frames] """ + component_indices = cast(np.ndarray, component_indices) # guaranteed by index parsing decorator cnmf_obj = self.get_output() if cnmf_obj.estimates.F_dff is None: @@ -754,10 +733,18 @@ def run_eval(self, params: dict) -> None: self._series["params"]["eval"] = deepcopy(params) @validate("cnmf") - def get_good_components(self) -> np.ndarray: + @cnmf_cache.use_cache + def get_good_components(self, *, return_copy=True) -> np.ndarray: """ get the good component indices, ``Estimates.idx_components`` + Parameters + ---------- + return_copy: bool + | if ``True`` returns a copy of the cached value in memory. + | if ``False`` returns the same object as the cached value in memory, not recommend this could result in strange unexpected behavior. + | In general you want a copy of the cached value. + Returns ------- np.ndarray @@ -769,10 +756,18 @@ def get_good_components(self) -> np.ndarray: return cnmf_obj.estimates.idx_components @validate("cnmf") - def get_bad_components(self) -> np.ndarray: + @cnmf_cache.use_cache + def get_bad_components(self, *, return_copy=True) -> np.ndarray: """ get the bad component indices, ``Estimates.idx_components_bad`` + Parameters + ---------- + return_copy: bool + | if ``True`` returns a copy of the cached value in memory. + | if ``False`` returns the same object as the cached value in memory, not recommend this could result in strange unexpected behavior. + | In general you want a copy of the cached value. + Returns ------- np.ndarray @@ -782,3 +777,24 @@ def get_bad_components(self) -> np.ndarray: cnmf_obj = self.get_output() return cnmf_obj.estimates.idx_components_bad + + @validate("cnmf") + @cnmf_cache.use_cache + def get_n_components(self, *, return_copy=True) -> int: + """ + get total number of components (good + bad) + + Parameters + ---------- + return_copy: bool + | if ``True`` returns a copy of the cached value in memory. + | if ``False`` returns the same object as the cached value in memory, not recommend this could result in strange unexpected behavior. + | In general you want a copy of the cached value. + + Returns + ------- + int + number of components + """ + cnmf_obj = self.get_output() + return cnmf_obj.estimates.A.shape[1] diff --git a/mesmerize_core/utils.py b/mesmerize_core/utils.py index e168017..4212747 100644 --- a/mesmerize_core/utils.py +++ b/mesmerize_core/utils.py @@ -33,13 +33,26 @@ MESMERIZE_LRU_CACHE = 10 +def wrapsmethod(wrapper): + """ + functools.wraps doesn't type its return value properly for use as a method, + so use this to disable functools.wraps when type checking (since it only matters at runtime) + """ + def decorator(fn): + if TYPE_CHECKING: + return fn + else: + return wraps(wrapper)(fn) + return decorator + + def warning_experimental(more_info: str = ""): """ decorator to warn the user that the function is experimental """ def catcher(func): - @wraps(func) + @wrapsmethod(func) def fn(self, *args, **kwargs): warn( f"You are trying to use the following experimental feature, " diff --git a/tests/test_core.py b/tests/test_core.py index b8a7419..8c1fd75 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1179,36 +1179,32 @@ def test_cache(): # testing that cache size limits work cnmf.cnmf_cache.set_maxsize("1M") - cnmf_output = df.iloc[-1].cnmf.get_output() - hex_get_output = hex(id(cnmf_output)) + df.iloc[-1].cnmf.get_output() # cache entry 0 (get_output) cache = cnmf.cnmf_cache.get_cache() hex1 = hex(id(cache[cache["function"] == "get_output"]["return_val"].item())) - # assert(hex(id(df.iloc[-1].cnmf.get_output(copy=False))) == hex1) - # assert(hex_get_output != hex1) + time_stamp1 = cache[cache["function"] == "get_output"]["time_stamp"].item() - df.iloc[-1].cnmf.get_temporal("good") - df.iloc[-1].cnmf.get_contours("good") - df.iloc[-1].cnmf.get_masks("good") - df.iloc[-1].cnmf.get_temporal(np.arange(7)) - df.iloc[-1].cnmf.get_temporal(np.arange(8)) - df.iloc[-1].cnmf.get_temporal(np.arange(9)) - df.iloc[-1].cnmf.get_temporal(np.arange(6)) - df.iloc[-1].cnmf.get_temporal(np.arange(5)) - df.iloc[-1].cnmf.get_temporal(np.arange(4)) - df.iloc[-1].cnmf.get_temporal(np.arange(3)) - df.iloc[-1].cnmf.get_masks(np.arange(8)) - df.iloc[-1].cnmf.get_masks(np.arange(9)) - df.iloc[-1].cnmf.get_masks(np.arange(7)) - df.iloc[-1].cnmf.get_masks(np.arange(6)) - df.iloc[-1].cnmf.get_masks(np.arange(5)) - df.iloc[-1].cnmf.get_masks(np.arange(4)) - df.iloc[-1].cnmf.get_masks(np.arange(3)) - time_stamp2 = cache[cache["function"] == "get_output"]["time_stamp"].item() - hex2 = hex(id(cache[cache["function"] == "get_output"]["return_val"].item())) - assert cache[cache["function"] == "get_output"].index.size == 1 + df.iloc[-1].cnmf.get_temporal("good") # cache entry 1 (get_good_components) + 2 (get_temporal) + df.iloc[-1].cnmf.get_contours("good") # cache entry 3 (get_contours) + df.iloc[-1].cnmf.get_masks("good") # cache entry 4 (get_masks) + df.iloc[-1].cnmf.get_temporal(np.arange(7)) # cache entry 5 (get_temporal) + df.iloc[-1].cnmf.get_temporal(np.arange(8)) # cache entry 6, 2 gets evicted + + assert hex(id(cache)) == hex(id(cnmf.cnmf_cache.get_cache())), \ + "cache object should still be the same after evicting cache items" + # after adding enough items for cache to exceed max size, cache should remove least recently used items until # size is back under max - assert len(cnmf.cnmf_cache.get_cache().index) == 17 + assert len(cache) == 6 + + output_items = cache[cache["function"] == "get_output"] + assert len(output_items) > 0, "output should not be evicted since it's accessed for every other function" + assert len(output_items) == 1, "output shuould not be duplicated in the cache" + + time_stamp2 = output_items["time_stamp"].item() + hex2 = hex(id(output_items["return_val"].item())) + assert cache[cache["function"] == "get_output"].index.size == 1 + # the time stamp to get_output the second time should be greater than the original time # stamp because the cached item is being returned and therefore will have been accessed more recently assert time_stamp2 > time_stamp1 @@ -1285,7 +1281,7 @@ def test_cache(): df = load_batch(batch_path) - cnmf.cnmf_cache.set_maxsize("1M") + cnmf.cnmf_cache.set_maxsize("2M") df.iloc[1].cnmf.get_output() # cnmf output df.iloc[-1].cnmf.get_output() # cnmfe output @@ -1326,13 +1322,35 @@ def test_cache(): # test for copy # if return_copy=True, then hex id of calls to the same function should be false output = df.iloc[1].cnmf.get_output() - assert hex(id(output)) != hex( - id(cache.sort_values(by=["time_stamp"], ascending=True).iloc[-1]) - ) - # if return_copy=False, then hex id of calls to the same function should be true - output = df.iloc[1].cnmf.get_output(return_copy=False) + output_cache_entry = cache.sort_values(by=["time_stamp"], ascending=True).iloc[-1] + hex_orig = hex(id(output_cache_entry["return_val"])) + time_orig = output_cache_entry["added_time"] + assert hex(id(output)) != hex_orig + + # return_copy should't be considered when comparing function calls + # better to compare added times than hex because 2 different cache entries could refer to the same object + # (not for get_output but yes for other functions) output2 = df.iloc[1].cnmf.get_output(return_copy=False) - assert hex(id(output)) == hex(id(output2)) - assert hex(id(cnmf.cnmf_cache.get_cache().iloc[-1]["return_val"])) == hex( - id(output) - ) + last_cache_entry = cache.sort_values(by=["time_stamp"], ascending=True).iloc[-1] + assert last_cache_entry["added_time"] == time_orig + + # if return_copy=False, then hex id of calls to the same function should be true + output3 = df.iloc[1].cnmf.get_output(return_copy=False) + assert hex(id(output3)) == hex(id(output2)) + last_cache_entry = cache.sort_values(by=["time_stamp"], ascending=True).iloc[-1] + assert last_cache_entry["added_time"] == time_orig + + # shouldn't matter for comparison whether arguments are passed positionally, by keyword, + # or by default if their args are the same + same1 = df.iloc[1].cnmf.get_temporal("good", return_copy=False) # add_background=False by default + same1_time = cache.sort_values(by=["time_stamp"], ascending=True).iloc[-1]["added_time"] + same2 = df.iloc[1].cnmf.get_temporal("good", False, return_copy=False) + same2_time = cache.sort_values(by=["time_stamp"], ascending=True).iloc[-1]["added_time"] + same3 = df.iloc[1].cnmf.get_temporal("good", add_background=False, return_copy=False) + same3_time = cache.sort_values(by=["time_stamp"], ascending=True).iloc[-1]["added_time"] + different = df.iloc[1].cnmf.get_temporal("good", add_background=True, return_copy=False) + different_time = cache.sort_values(by=["time_stamp"], ascending=True).iloc[-1]["added_time"] + + assert hex(id(same1)) == hex(id(same2)) and same1_time == same2_time, "Matching default argument should cause hit" + assert hex(id(same2)) == hex(id(same3)) and same2_time == same3_time, "Matching keyword/non-keyword arguments should cause hit" + assert hex(id(same3)) != hex(id(different)) and same3_time != different_time, "Non-matching arguments should cause miss"