33import numpy as np
44from caiman .utils .utils import load_dict_from_hdf5
55from caiman .source_extraction .cnmf .cnmf import CNMF
6+ from caiman .base .rois import extract_binary_masks_from_structural_channel
67import numpy .testing
78import pandas as pd
89from mesmerize_core import (
3435from copy import deepcopy
3536
3637# don't call "resolve" on these - want to make sure we can handle non-canonical paths correctly
37- tmp_dir = Path (os .path .dirname (os .path .abspath (__file__ )), "test data" , "tmp" )
38- vid_dir = Path (os .path .dirname (os .path .abspath (__file__ )), "test data" , "videos" )
39- ground_truths_dir = Path (
40- os .path .dirname (os .path .abspath (__file__ )), "test data" , "ground_truths"
41- )
42- ground_truths_file = Path (
43- os .path .dirname (os .path .abspath (__file__ )), "test data" , "ground_truths.zip"
44- )
38+ testdata_dir = Path (os .path .dirname (os .path .abspath (__file__ )), "test data" )
39+ tmp_dir = testdata_dir / "tmp"
40+ vid_dir = testdata_dir / "videos"
41+ seed_dir = testdata_dir / "seeds"
42+ ground_truths_dir = testdata_dir / "ground_truths"
43+ ground_truths_file = testdata_dir / "ground_truths.zip"
4544
4645os .makedirs (tmp_dir , exist_ok = True )
4746os .makedirs (vid_dir , exist_ok = True )
47+ os .makedirs (seed_dir , exist_ok = True )
4848os .makedirs (ground_truths_dir , exist_ok = True )
4949
5050
5151def _download_ground_truths ():
5252 print (f"Downloading ground truths" )
53- url = f"https://zenodo.org/record/14934525 /files/ground_truths.zip"
53+ url = f"https://zenodo.org/records/17059175 /files/ground_truths.zip"
5454
5555 # basically from https://stackoverflow.com/questions/37573483/progress-bar-while-download-file-over-http-with-requests/37573701
5656 response = requests .get (url , stream = True )
@@ -141,6 +141,12 @@ def _create_tmp_batch() -> tuple[pd.DataFrame, str]:
141141 return df , fname
142142
143143
144+ def make_test_seed (input_data : np .ndarray ):
145+ """Function call used to create Ain for testing"""
146+ mean_proj = np .mean (input_data , axis = 0 )
147+ return extract_binary_masks_from_structural_channel (mean_proj , gSig = 3 )[0 ]
148+
149+
144150def test_create_batch ():
145151 df , fname = _create_tmp_batch ()
146152
@@ -1336,3 +1342,183 @@ def test_cache():
13361342 assert hex (id (cnmf .cnmf_cache .get_cache ().iloc [- 1 ]["return_val" ])) == hex (
13371343 id (output )
13381344 )
1345+
1346+
1347+ def test_seeded_cnmf ():
1348+ """Test seeded CNNF (Ain)"""
1349+ set_parent_raw_data_path (vid_dir )
1350+ algo = "mcorr"
1351+
1352+ df , batch_path = _create_tmp_batch ()
1353+
1354+ batch_path = Path (batch_path )
1355+ batch_dir = batch_path .parent
1356+ batch_dir_canon = batch_dir .resolve ()
1357+
1358+ input_movie_path = get_datafile (algo )
1359+ print (input_movie_path )
1360+
1361+ df .caiman .add_item (
1362+ algo = algo ,
1363+ item_name = f"test-{ algo } " ,
1364+ input_movie_path = input_movie_path ,
1365+ params = test_params [algo ],
1366+ )
1367+
1368+ df .iloc [- 1 ].caiman .run ()
1369+ df = load_batch (batch_path )
1370+
1371+ assert df .iloc [- 1 ]["outputs" ]["success" ] is True
1372+ assert df .iloc [- 1 ]["outputs" ]["traceback" ] is None
1373+
1374+ # make seed
1375+ mcorr_output = df .iloc [- 1 ].mcorr .get_output ()
1376+ seed = make_test_seed (mcorr_output )
1377+ seed_path = seed_dir / "Ain_cnmf.npy"
1378+ np .save (seed_path , seed )
1379+
1380+ algo = "cnmf"
1381+ print ("Testing seeded cnmf" )
1382+ input_movie_path = df .iloc [- 1 ].mcorr .get_output_path ()
1383+ seeded_params = {
1384+ ** test_params [algo ],
1385+ "Ain_path" : seed_path ,
1386+ "refit" : False
1387+ }
1388+
1389+ df .caiman .add_item (
1390+ algo = algo ,
1391+ item_name = f"test-seeded-{ algo } " ,
1392+ input_movie_path = input_movie_path ,
1393+ params = seeded_params
1394+ )
1395+
1396+ assert df .iloc [- 1 ]["algo" ] == algo
1397+ assert df .iloc [- 1 ]["item_name" ] == f"test-seeded-{ algo } "
1398+ assert df .iloc [- 1 ]["params" ] == seeded_params
1399+ assert df .iloc [- 1 ]["outputs" ] is None
1400+ try :
1401+ UUID (df .iloc [- 1 ]["uuid" ])
1402+ except :
1403+ pytest .fail ("Something wrong with setting UUID for batch items" )
1404+ print ("cnmf input_movie_path:" , df .iloc [- 1 ]["input_movie_path" ])
1405+ assert batch_dir_canon .joinpath (df .iloc [- 1 ]["input_movie_path" ]) == input_movie_path
1406+
1407+ df .iloc [- 1 ].caiman .run ()
1408+
1409+ df = load_batch (batch_path )
1410+
1411+ with pd .option_context ("display.max_rows" , None , "display.max_columns" , None ):
1412+ print (df )
1413+
1414+ pprint (df .iloc [- 1 ]["outputs" ], width = - 1 )
1415+ print (df .iloc [- 1 ]["outputs" ]["traceback" ])
1416+
1417+ assert df .iloc [- 1 ]["outputs" ]["success" ] is True
1418+ assert df .iloc [- 1 ]["outputs" ]["traceback" ] is None
1419+
1420+ # test to check cnmf get_masks()
1421+ cnmf_spatial_masks = df .iloc [- 1 ].cnmf .get_masks ("good" )
1422+ cnmf_spatial_masks_actual = numpy .load (
1423+ ground_truths_dir .joinpath ("cnmf_seeded" , "spatial_masks.npy" )
1424+ )
1425+ numpy .testing .assert_array_equal (cnmf_spatial_masks , cnmf_spatial_masks_actual )
1426+
1427+ # test to check get_temporal()
1428+ cnmf_temporal_components = df .iloc [- 1 ].cnmf .get_temporal ("good" )
1429+ cnmf_temporal_components_actual = numpy .load (
1430+ ground_truths_dir .joinpath ("cnmf_seeded" , "temporal_components.npy" )
1431+ )
1432+ numpy .testing .assert_allclose (
1433+ cnmf_temporal_components , cnmf_temporal_components_actual , rtol = 1e-2 , atol = 1e-10
1434+ )
1435+
1436+
1437+ def test_seeded_cnmfe ():
1438+ set_parent_raw_data_path (vid_dir )
1439+
1440+ df , batch_path = _create_tmp_batch ()
1441+
1442+ batch_path = Path (batch_path )
1443+ batch_dir = batch_path .parent
1444+ batch_dir_canon = batch_dir .resolve ()
1445+
1446+ input_movie_path = get_datafile ("cnmfe" )
1447+ print (input_movie_path )
1448+ df .caiman .add_item (
1449+ algo = "mcorr" ,
1450+ item_name = "test-cnmfe-mcorr" ,
1451+ input_movie_path = input_movie_path ,
1452+ params = test_params ["mcorr" ],
1453+ )
1454+ df .iloc [- 1 ].caiman .run ()
1455+
1456+ df = load_batch (batch_path )
1457+
1458+ # Test if running seeded cnmfe works
1459+ # this seed is actually trash for CNMFE but just see if it's consistent
1460+ mcorr_output = df .iloc [- 1 ].mcorr .get_output ()
1461+ seed = make_test_seed (mcorr_output )
1462+ seed_path = seed_dir / "Ain_cnmfe.npy"
1463+ np .save (seed_path , seed )
1464+
1465+ print ("testing seeded cnmfe" )
1466+ algo = "cnmfe"
1467+ param_name = "cnmfe_full"
1468+ input_movie_path = df .iloc [0 ].mcorr .get_output_path ()
1469+ print (input_movie_path )
1470+ seeded_params = {
1471+ ** test_params [param_name ],
1472+ "Ain_path" : seed_path ,
1473+ "refit" : False
1474+ }
1475+
1476+ df .caiman .add_item (
1477+ algo = algo ,
1478+ item_name = f"test-seeded-{ algo } " ,
1479+ input_movie_path = input_movie_path ,
1480+ params = seeded_params ,
1481+ )
1482+
1483+ assert df .iloc [- 1 ]["algo" ] == algo
1484+ assert df .iloc [- 1 ]["item_name" ] == f"test-seeded-{ algo } "
1485+ assert df .iloc [- 1 ]["params" ] == seeded_params
1486+ assert df .iloc [- 1 ]["outputs" ] is None
1487+ try :
1488+ UUID (df .iloc [- 1 ]["uuid" ])
1489+ except :
1490+ pytest .fail ("Something wrong with setting UUID for batch items" )
1491+
1492+ assert (
1493+ batch_dir_canon .joinpath (df .iloc [- 1 ]["input_movie_path" ])
1494+ == batch_dir_canon .joinpath (df .iloc [0 ].mcorr .get_output_path ())
1495+ == df .paths .resolve (df .iloc [- 1 ]["input_movie_path" ])
1496+ )
1497+
1498+ df .iloc [- 1 ].caiman .run ()
1499+ df = load_batch (batch_path )
1500+
1501+ with pd .option_context ("display.max_rows" , None , "display.max_columns" , None ):
1502+ print (df )
1503+
1504+ pprint (df .iloc [- 1 ]["outputs" ], width = - 1 )
1505+ print (df .iloc [- 1 ]["outputs" ]["traceback" ])
1506+
1507+ assert df .iloc [- 1 ]["outputs" ]["success" ] is True
1508+ assert df .iloc [- 1 ]["outputs" ]["traceback" ] is None
1509+
1510+ # test to check cnmf get_masks()
1511+ cnmf_spatial_masks = df .iloc [- 1 ].cnmf .get_masks ("good" )
1512+ cnmf_spatial_masks_actual = numpy .load (
1513+ ground_truths_dir .joinpath ("cnmfe_seeded" , "spatial_masks.npy" )
1514+ )
1515+ numpy .testing .assert_array_equal (cnmf_spatial_masks , cnmf_spatial_masks_actual )
1516+
1517+ # test to check get_temporal()
1518+ cnmf_temporal_components = df .iloc [- 1 ].cnmf .get_temporal ("good" )
1519+ cnmf_temporal_components_actual = numpy .load (
1520+ ground_truths_dir .joinpath ("cnmfe_seeded" , "temporal_components.npy" )
1521+ )
1522+ numpy .testing .assert_allclose (
1523+ cnmf_temporal_components , cnmf_temporal_components_actual , rtol = 1e-2 , atol = 1e-10
1524+ )
0 commit comments