Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions mesmerize_core/algorithms/cnmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,24 @@ def run_algo(batch_path, uuid, data_path: str = None):
backend="local", n_processes=None, single_thread=False
)

# load Ain if given
if 'Ain_path' in params and params['Ain_path'] is not None:
Ain_path_abs = output_dir / params['Ain_path'] # resolve relative to output dir
Ain = np.load(Ain_path_abs, allow_pickle=True)
if Ain.size == 1: # sparse array loaded as object
Ain = Ain.item()

# force params needed for seeded CNMF
cnmf_params.change_params({'patch': {'rf': None, 'only_init': False}})
else:
Ain = None

print("performing CNMF")
cnm = cnmf.CNMF(n_processes, params=cnmf_params, dview=dview)
cnm = cnmf.CNMF(n_processes, params=cnmf_params, dview=dview, Ain=Ain)

print("fitting images")
cnm.fit(images)
#

if "refit" in params.keys():
if params["refit"] is True:
print("refitting")
Expand Down
53 changes: 37 additions & 16 deletions mesmerize_core/algorithms/cnmfe.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,30 @@ def run_algo(batch_path, uuid, data_path: str = None):
c, dview, n_processes = cm.cluster.setup_cluster(
backend="local", n_processes=n_processes, single_thread=False
)

try:
# force the CNMFE params
cnmfe_params = {
"method_init": "corr_pnr",
"n_processes": n_processes,
"only_init": True, # for 1p
"center_psf": True, # for 1p
"normalize_init": False, # for 1p
}

params_dict = {**cnmfe_params, **params["main"]}

cnmfe_params = CNMFParams(params_dict=params_dict)

print("making memmap")
fname_new = cm.save_memmap(
[input_movie_path], base_name=f"{uuid}_cnmf-memmap_", order="C", dview=dview
[input_movie_path],
base_name=f"{uuid}_cnmf-memmap_",
order="C",
dview=dview,
var_name_hdf5=cnmfe_params.data['var_name_hdf5']
)

print("making memmap")
Yr, dims, T = cm.load_memmap(fname_new)
images = np.reshape(Yr.T, [T] + list(dims), order="F")

Expand All @@ -69,19 +86,23 @@ def run_algo(batch_path, uuid, data_path: str = None):

d = dict() # for output

# force the CNMFE params
cnmfe_params_dict = {
"method_init": "corr_pnr",
"n_processes": n_processes,
"only_init": True, # for 1p
"center_psf": True, # for 1p
"normalize_init": False, # for 1p
}

params_dict = {**cnmfe_params_dict, **params["main"]}

cnmfe_params_dict = CNMFParams(params_dict=params_dict)
cnm = cnmf.CNMF(n_processes=n_processes, dview=dview, params=cnmfe_params_dict)
# load Ain if given
if "Ain_path" in params and params["Ain_path"] is not None:
Ain_path_abs = (
output_dir / params["Ain_path"]
) # resolve relative to output dir
Ain = np.load(Ain_path_abs, allow_pickle=True)
if Ain.size == 1: # sparse array loaded as object
Ain = Ain.item()

# force params needed for seeded CNMFE
cnmfe_params.change_params({'patch': {'rf': None, 'only_init': False}})
else:
Ain = None

cnm = cnmf.CNMF(
n_processes=n_processes, dview=dview, params=cnmfe_params, Ain=Ain
)
print("Performing CNMFE")
cnm.fit(images)
print("evaluating components")
Expand Down
66 changes: 5 additions & 61 deletions mesmerize_core/caiman_extensions/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
get_parent_raw_data_path,
load_batch,
)
from ..utils import IS_WINDOWS, make_runfile, warning_experimental
from ..utils import IS_WINDOWS, make_runfile, warning_experimental, get_params_diffs
from .cnmf import cnmf_cache
from .. import algorithms
from ..movie_readers import default_reader
Expand Down Expand Up @@ -349,22 +349,7 @@ def get_params_diffs(self, algo: str, item_name: str) -> pd.DataFrame:
index of the original DataFrame

"""

def flatten_params(params_dict: dict):
"""
Produce a flat dict with one entry for each parameter in the passed dict.
If params_dict['main'] is nested one level (e.g., {'init': {'K': 5}, 'merging': {'merge_thr': 0.85}}...),
each key in the output is <outerKey>.<innerKey>, e.g., [(init.K, 5), (merging.merge_thr, 0.85)]
"""
params = {}
for key1, val1 in params_dict.items():
if isinstance(val1, dict): # nested
for key2, val2 in val1.items():
params[f"{key1}.{key2}"] = val2
else:
params[key1] = val1
return params


sub_df = self._df[self._df["item_name"] == item_name]
sub_df = sub_df[sub_df["algo"] == algo]

Expand All @@ -373,52 +358,11 @@ def flatten_params(params_dict: dict):
f"The given `item_name`: {item_name}, does not exist in the DataFrame"
)

# get flattened parameters for each of the filtered items
params_flat = sub_df.params.map(lambda p: flatten_params(p["main"]))

# build list of params that differ between different parameter sets
common_params = deepcopy(
params_flat.iat[0]
) # holds the common value for parameters found in all sets (so far)
varying_params = (
set()
) # set of parameter keys that appear in not all sets or with varying values

for this_params in params_flat.iloc[1:]:
# first, anything that's not in both this dict and the common set is considered varying
common_paramset = set(common_params.keys())
for not_common_key in common_paramset.symmetric_difference(
this_params.keys()
):
varying_params.add(not_common_key)
if not_common_key in common_paramset:
del common_params[not_common_key]
common_paramset.remove(not_common_key)

# second, look at params in the common set and remove any that differ for this set
for (
key
) in (
common_paramset
): # iterate over this set rather than dict itself to avoid issues when deleting entries
if not np.array_equal(
common_params[key], this_params[key]
): # (should also work for scalars/arbitrary objects)
varying_params.add(key)
del common_params[key]

# gives a list where each item is a dict that has the unique params that correspond to a row
# the indices of this series correspond to the index of the row in the parent dataframe
diffs = params_flat.map(
lambda p: {
key: p[key] if key in p else "<default>" for key in varying_params
}
)
params_list = sub_df.params.tolist()
diffs = get_params_diffs(params_list)

# return as a nicely formatted dataframe
diffs_df = pd.DataFrame.from_dict(diffs.tolist(), dtype=object).set_index(
diffs.index
)
diffs_df = pd.DataFrame.from_dict(diffs, dtype=object).set_index(sub_df.index)

return diffs_df

Expand Down
49 changes: 49 additions & 0 deletions mesmerize_core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import sys
from tempfile import NamedTemporaryFile
from subprocess import check_call
from copy import deepcopy
import pandas as pd
import shlex
import mslex

Expand Down Expand Up @@ -170,3 +172,50 @@ def _organize_coordinates(contour: dict):
coors = coors[~np.isnan(coors).any(axis=1)]

return coors


def flatten_params(params_dict: dict) -> dict:
"""
Produce a flat dict with one entry for each parameter in the passed dict.
If params_dict['main'] is nested one level (e.g., {'init': {'K': 5}, 'merging': {'merge_thr': 0.85}}...),
each key in the output is <outerKey>.<innerKey>, e.g., [(init.K, 5), (merging.merge_thr, 0.85)]
"""
params = {}
for key1, val1 in params_dict.items():
if key1 == "main":
# recursively step into "main" params
params.update(flatten_params(val1))
elif isinstance(val1, dict): # nested
for key2, val2 in val1.items():
params[f"{key1}.{key2}"] = val2
else:
params[key1] = val1
return params


def get_params_diffs(params: Sequence[dict]) -> list[dict]:
"""Compute differences between params used for mesmerize"""
# get flattened parameters for each of the filtered items
params_flat = list(map(flatten_params, params))

# build list of params that differ between different parameter sets
common_params = deepcopy(params_flat[0]) # holds the common value for parameters found in all sets (so far)
varying_params = set() # set of parameter keys that appear in not all sets or with varying values

for this_params in params_flat[1:]:
# first, anything that's not in both this dict and the common set is considered varying
common_paramset = set(common_params.keys())
for not_common_key in common_paramset.symmetric_difference(this_params.keys()):
varying_params.add(not_common_key)
if not_common_key in common_paramset:
del common_params[not_common_key]
common_paramset.remove(not_common_key)

# second, look at params in the common set and remove any that differ for this set
for key in common_paramset: # iterate over this set rather than dict itself to avoid issues when deleting entries
if not np.array_equal(common_params[key], this_params[key]): # (should also work for scalars/arbitrary objects)
varying_params.add(key)
del common_params[key]

# gives a list where each item is a dict that has the unique params that correspond to a row
return [{key: p[key] if key in p else "<default>" for key in varying_params} for p in params_flat]
Loading
Loading