Skip to content

Commit ad05c13

Browse files
committed
Ensure correct params are set when Ain is provided; add tests
1 parent 7c0bee9 commit ad05c13

File tree

3 files changed

+207
-15
lines changed

3 files changed

+207
-15
lines changed

mesmerize_core/algorithms/cnmf.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ def run_algo(batch_path, uuid, data_path: str = None):
8888
Ain = np.load(Ain_path_abs, allow_pickle=True)
8989
if Ain.size == 1: # sparse array loaded as object
9090
Ain = Ain.item()
91+
92+
# force params needed for seeded CNMF
93+
cnmf_params.change_params({'patch': {'rf': None, 'only_init': False}})
9194
else:
9295
Ain = None
9396

@@ -96,7 +99,7 @@ def run_algo(batch_path, uuid, data_path: str = None):
9699

97100
print("fitting images")
98101
cnm.fit(images)
99-
#
102+
100103
if "refit" in params.keys():
101104
if params["refit"] is True:
102105
print("refitting")

mesmerize_core/algorithms/cnmfe.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,25 +50,25 @@ def run_algo(batch_path, uuid, data_path: str = None):
5050

5151
try:
5252
# force the CNMFE params
53-
cnmfe_params_dict = {
53+
cnmfe_params = {
5454
"method_init": "corr_pnr",
5555
"n_processes": n_processes,
5656
"only_init": True, # for 1p
5757
"center_psf": True, # for 1p
5858
"normalize_init": False, # for 1p
5959
}
6060

61-
params_dict = {**cnmfe_params_dict, **params["main"]}
61+
params_dict = {**cnmfe_params, **params["main"]}
6262

63-
cnmfe_params_dict = CNMFParams(params_dict=params_dict)
63+
cnmfe_params = CNMFParams(params_dict=params_dict)
6464

6565
print("making memmap")
6666
fname_new = cm.save_memmap(
6767
[input_movie_path],
6868
base_name=f"{uuid}_cnmf-memmap_",
6969
order="C",
7070
dview=dview,
71-
var_name_hdf5=cnmfe_params_dict.data['var_name_hdf5']
71+
var_name_hdf5=cnmfe_params.data['var_name_hdf5']
7272
)
7373

7474
Yr, dims, T = cm.load_memmap(fname_new)
@@ -94,11 +94,14 @@ def run_algo(batch_path, uuid, data_path: str = None):
9494
Ain = np.load(Ain_path_abs, allow_pickle=True)
9595
if Ain.size == 1: # sparse array loaded as object
9696
Ain = Ain.item()
97+
98+
# force params needed for seeded CNMFE
99+
cnmfe_params.change_params({'patch': {'rf': None, 'only_init': False}})
97100
else:
98101
Ain = None
99102

100103
cnm = cnmf.CNMF(
101-
n_processes=n_processes, dview=dview, params=cnmfe_params_dict, Ain=Ain
104+
n_processes=n_processes, dview=dview, params=cnmfe_params, Ain=Ain
102105
)
103106
print("Performing CNMFE")
104107
cnm.fit(images)

tests/test_core.py

Lines changed: 195 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
from caiman.utils.utils import load_dict_from_hdf5
55
from caiman.source_extraction.cnmf.cnmf import CNMF
6+
from caiman.base.rois import extract_binary_masks_from_structural_channel
67
import numpy.testing
78
import pandas as pd
89
from mesmerize_core import (
@@ -34,23 +35,22 @@
3435
from copy import deepcopy
3536

3637
# don't call "resolve" on these - want to make sure we can handle non-canonical paths correctly
37-
tmp_dir = Path(os.path.dirname(os.path.abspath(__file__)), "test data", "tmp")
38-
vid_dir = Path(os.path.dirname(os.path.abspath(__file__)), "test data", "videos")
39-
ground_truths_dir = Path(
40-
os.path.dirname(os.path.abspath(__file__)), "test data", "ground_truths"
41-
)
42-
ground_truths_file = Path(
43-
os.path.dirname(os.path.abspath(__file__)), "test data", "ground_truths.zip"
44-
)
38+
testdata_dir = Path(os.path.dirname(os.path.abspath(__file__)), "test data")
39+
tmp_dir = testdata_dir / "tmp"
40+
vid_dir = testdata_dir / "videos"
41+
seed_dir = testdata_dir / "seeds"
42+
ground_truths_dir = testdata_dir / "ground_truths"
43+
ground_truths_file = testdata_dir / "ground_truths.zip"
4544

4645
os.makedirs(tmp_dir, exist_ok=True)
4746
os.makedirs(vid_dir, exist_ok=True)
47+
os.makedirs(seed_dir, exist_ok=True)
4848
os.makedirs(ground_truths_dir, exist_ok=True)
4949

5050

5151
def _download_ground_truths():
5252
print(f"Downloading ground truths")
53-
url = f"https://zenodo.org/record/14934525/files/ground_truths.zip"
53+
url = f"https://zenodo.org/records/17059175/files/ground_truths.zip"
5454

5555
# basically from https://stackoverflow.com/questions/37573483/progress-bar-while-download-file-over-http-with-requests/37573701
5656
response = requests.get(url, stream=True)
@@ -141,6 +141,12 @@ def _create_tmp_batch() -> tuple[pd.DataFrame, str]:
141141
return df, fname
142142

143143

144+
def make_test_seed(input_data: np.ndarray):
145+
"""Function call used to create Ain for testing"""
146+
mean_proj = np.mean(input_data, axis=0)
147+
return extract_binary_masks_from_structural_channel(mean_proj, gSig=3)[0]
148+
149+
144150
def test_create_batch():
145151
df, fname = _create_tmp_batch()
146152

@@ -1336,3 +1342,183 @@ def test_cache():
13361342
assert hex(id(cnmf.cnmf_cache.get_cache().iloc[-1]["return_val"])) == hex(
13371343
id(output)
13381344
)
1345+
1346+
1347+
def test_seeded_cnmf():
1348+
"""Test seeded CNNF (Ain)"""
1349+
set_parent_raw_data_path(vid_dir)
1350+
algo = "mcorr"
1351+
1352+
df, batch_path = _create_tmp_batch()
1353+
1354+
batch_path = Path(batch_path)
1355+
batch_dir = batch_path.parent
1356+
batch_dir_canon = batch_dir.resolve()
1357+
1358+
input_movie_path = get_datafile(algo)
1359+
print(input_movie_path)
1360+
1361+
df.caiman.add_item(
1362+
algo=algo,
1363+
item_name=f"test-{algo}",
1364+
input_movie_path=input_movie_path,
1365+
params=test_params[algo],
1366+
)
1367+
1368+
df.iloc[-1].caiman.run()
1369+
df = load_batch(batch_path)
1370+
1371+
assert df.iloc[-1]["outputs"]["success"] is True
1372+
assert df.iloc[-1]["outputs"]["traceback"] is None
1373+
1374+
# make seed
1375+
mcorr_output = df.iloc[-1].mcorr.get_output()
1376+
seed = make_test_seed(mcorr_output)
1377+
seed_path = seed_dir / "Ain_cnmf.npy"
1378+
np.save(seed_path, seed)
1379+
1380+
algo = "cnmf"
1381+
print("Testing seeded cnmf")
1382+
input_movie_path = df.iloc[-1].mcorr.get_output_path()
1383+
seeded_params = {
1384+
**test_params[algo],
1385+
"Ain_path": seed_path,
1386+
"refit": False
1387+
}
1388+
1389+
df.caiman.add_item(
1390+
algo=algo,
1391+
item_name=f"test-seeded-{algo}",
1392+
input_movie_path=input_movie_path,
1393+
params=seeded_params
1394+
)
1395+
1396+
assert df.iloc[-1]["algo"] == algo
1397+
assert df.iloc[-1]["item_name"] == f"test-seeded-{algo}"
1398+
assert df.iloc[-1]["params"] == seeded_params
1399+
assert df.iloc[-1]["outputs"] is None
1400+
try:
1401+
UUID(df.iloc[-1]["uuid"])
1402+
except:
1403+
pytest.fail("Something wrong with setting UUID for batch items")
1404+
print("cnmf input_movie_path:", df.iloc[-1]["input_movie_path"])
1405+
assert batch_dir_canon.joinpath(df.iloc[-1]["input_movie_path"]) == input_movie_path
1406+
1407+
df.iloc[-1].caiman.run()
1408+
1409+
df = load_batch(batch_path)
1410+
1411+
with pd.option_context("display.max_rows", None, "display.max_columns", None):
1412+
print(df)
1413+
1414+
pprint(df.iloc[-1]["outputs"], width=-1)
1415+
print(df.iloc[-1]["outputs"]["traceback"])
1416+
1417+
assert df.iloc[-1]["outputs"]["success"] is True
1418+
assert df.iloc[-1]["outputs"]["traceback"] is None
1419+
1420+
# test to check cnmf get_masks()
1421+
cnmf_spatial_masks = df.iloc[-1].cnmf.get_masks("good")
1422+
cnmf_spatial_masks_actual = numpy.load(
1423+
ground_truths_dir.joinpath("cnmf_seeded", "spatial_masks.npy")
1424+
)
1425+
numpy.testing.assert_array_equal(cnmf_spatial_masks, cnmf_spatial_masks_actual)
1426+
1427+
# test to check get_temporal()
1428+
cnmf_temporal_components = df.iloc[-1].cnmf.get_temporal("good")
1429+
cnmf_temporal_components_actual = numpy.load(
1430+
ground_truths_dir.joinpath("cnmf_seeded", "temporal_components.npy")
1431+
)
1432+
numpy.testing.assert_allclose(
1433+
cnmf_temporal_components, cnmf_temporal_components_actual, rtol=1e-2, atol=1e-10
1434+
)
1435+
1436+
1437+
def test_seeded_cnmfe():
1438+
set_parent_raw_data_path(vid_dir)
1439+
1440+
df, batch_path = _create_tmp_batch()
1441+
1442+
batch_path = Path(batch_path)
1443+
batch_dir = batch_path.parent
1444+
batch_dir_canon = batch_dir.resolve()
1445+
1446+
input_movie_path = get_datafile("cnmfe")
1447+
print(input_movie_path)
1448+
df.caiman.add_item(
1449+
algo="mcorr",
1450+
item_name="test-cnmfe-mcorr",
1451+
input_movie_path=input_movie_path,
1452+
params=test_params["mcorr"],
1453+
)
1454+
df.iloc[-1].caiman.run()
1455+
1456+
df = load_batch(batch_path)
1457+
1458+
# Test if running seeded cnmfe works
1459+
# this seed is actually trash for CNMFE but just see if it's consistent
1460+
mcorr_output = df.iloc[-1].mcorr.get_output()
1461+
seed = make_test_seed(mcorr_output)
1462+
seed_path = seed_dir / "Ain_cnmfe.npy"
1463+
np.save(seed_path, seed)
1464+
1465+
print("testing seeded cnmfe")
1466+
algo = "cnmfe"
1467+
param_name = "cnmfe_full"
1468+
input_movie_path = df.iloc[0].mcorr.get_output_path()
1469+
print(input_movie_path)
1470+
seeded_params = {
1471+
**test_params[param_name],
1472+
"Ain_path": seed_path,
1473+
"refit": False
1474+
}
1475+
1476+
df.caiman.add_item(
1477+
algo=algo,
1478+
item_name=f"test-seeded-{algo}",
1479+
input_movie_path=input_movie_path,
1480+
params=seeded_params,
1481+
)
1482+
1483+
assert df.iloc[-1]["algo"] == algo
1484+
assert df.iloc[-1]["item_name"] == f"test-seeded-{algo}"
1485+
assert df.iloc[-1]["params"] == seeded_params
1486+
assert df.iloc[-1]["outputs"] is None
1487+
try:
1488+
UUID(df.iloc[-1]["uuid"])
1489+
except:
1490+
pytest.fail("Something wrong with setting UUID for batch items")
1491+
1492+
assert (
1493+
batch_dir_canon.joinpath(df.iloc[-1]["input_movie_path"])
1494+
== batch_dir_canon.joinpath(df.iloc[0].mcorr.get_output_path())
1495+
== df.paths.resolve(df.iloc[-1]["input_movie_path"])
1496+
)
1497+
1498+
df.iloc[-1].caiman.run()
1499+
df = load_batch(batch_path)
1500+
1501+
with pd.option_context("display.max_rows", None, "display.max_columns", None):
1502+
print(df)
1503+
1504+
pprint(df.iloc[-1]["outputs"], width=-1)
1505+
print(df.iloc[-1]["outputs"]["traceback"])
1506+
1507+
assert df.iloc[-1]["outputs"]["success"] is True
1508+
assert df.iloc[-1]["outputs"]["traceback"] is None
1509+
1510+
# test to check cnmf get_masks()
1511+
cnmf_spatial_masks = df.iloc[-1].cnmf.get_masks("good")
1512+
cnmf_spatial_masks_actual = numpy.load(
1513+
ground_truths_dir.joinpath("cnmfe_seeded", "spatial_masks.npy")
1514+
)
1515+
numpy.testing.assert_array_equal(cnmf_spatial_masks, cnmf_spatial_masks_actual)
1516+
1517+
# test to check get_temporal()
1518+
cnmf_temporal_components = df.iloc[-1].cnmf.get_temporal("good")
1519+
cnmf_temporal_components_actual = numpy.load(
1520+
ground_truths_dir.joinpath("cnmfe_seeded", "temporal_components.npy")
1521+
)
1522+
numpy.testing.assert_allclose(
1523+
cnmf_temporal_components, cnmf_temporal_components_actual, rtol=1e-2, atol=1e-10
1524+
)

0 commit comments

Comments
 (0)