11# replay_trajectory_classification/tests/unit/test_likelihoods.py
22from __future__ import annotations
33
4- from typing import TYPE_CHECKING , Any
5-
64import numpy as np
75import pytest
86
7+ import replay_trajectory_classification .likelihoods .calcium_likelihood as calcium_likelihood
8+ import replay_trajectory_classification .likelihoods .multiunit_likelihood as multiunit_likelihood
9+ import replay_trajectory_classification .likelihoods .spiking_likelihood_glm as spiking_likelihood_glm
10+ import replay_trajectory_classification .likelihoods .spiking_likelihood_kde as spiking_likelihood_kde
11+
912# Test imports for likelihood modules
1013from replay_trajectory_classification .environments import Environment
1114
12- if TYPE_CHECKING :
13- pass
14-
15- # Try importing likelihood functions with graceful fallbacks
16- try :
17- import replay_trajectory_classification .likelihoods .spiking_likelihood_glm as spiking_likelihood_glm
18- except ImportError :
19- spiking_likelihood_glm : Any = None
20-
21- try :
22- import replay_trajectory_classification .likelihoods .spiking_likelihood_kde as spiking_likelihood_kde
23- except ImportError :
24- spiking_likelihood_kde : Any = None
25-
26- try :
27- import replay_trajectory_classification .likelihoods .multiunit_likelihood as multiunit_likelihood
28- except ImportError :
29- multiunit_likelihood : Any = None
30-
31- try :
32- import replay_trajectory_classification .likelihoods .calcium_likelihood as calcium_likelihood
33- except ImportError :
34- calcium_likelihood : Any = None
35-
36-
3715# ---------------------- Helpers ----------------------
3816
3917
@@ -68,7 +46,7 @@ def make_multiunit_data(n_electrodes=3, n_features=4, n_time=20):
6846
6947 for t in range (n_time ):
7048 if n_spikes_per_time [t ] < max_spikes :
71- no_spike_indicator [t , n_spikes_per_time [t ]:] = True
49+ no_spike_indicator [t , n_spikes_per_time [t ] :] = True
7250
7351 data [f"electrode_{ elec_id :02d} " ] = {
7452 "marks" : marks ,
@@ -82,28 +60,25 @@ def make_multiunit_data(n_electrodes=3, n_features=4, n_time=20):
8260
8361
8462@pytest .mark .skipif (
85- spiking_likelihood_glm is None ,
86- reason = "spiking_likelihood_glm module not available"
63+ spiking_likelihood_glm is None , reason = "spiking_likelihood_glm module not available"
8764)
8865def test_spiking_likelihood_glm_fit_exists ():
8966 """Test that GLM likelihood fit function exists and is callable."""
90- assert hasattr (spiking_likelihood_glm , ' estimate_place_fields' )
67+ assert hasattr (spiking_likelihood_glm , " estimate_place_fields" )
9168 assert callable (spiking_likelihood_glm .estimate_place_fields )
9269
9370
9471@pytest .mark .skipif (
95- spiking_likelihood_glm is None ,
96- reason = "spiking_likelihood_glm module not available"
72+ spiking_likelihood_glm is None , reason = "spiking_likelihood_glm module not available"
9773)
9874def test_spiking_likelihood_glm_estimate_exists ():
9975 """Test that GLM likelihood estimate function exists and is callable."""
100- assert hasattr (spiking_likelihood_glm , ' estimate_spiking_likelihood' )
76+ assert hasattr (spiking_likelihood_glm , " estimate_spiking_likelihood" )
10177 assert callable (spiking_likelihood_glm .estimate_spiking_likelihood )
10278
10379
10480@pytest .mark .skipif (
105- spiking_likelihood_glm is None ,
106- reason = "spiking_likelihood_glm module not available"
81+ spiking_likelihood_glm is None , reason = "spiking_likelihood_glm module not available"
10782)
10883def test_spiking_likelihood_glm_basic_functionality ():
10984 """Test basic GLM likelihood functionality with synthetic data."""
@@ -118,10 +93,10 @@ def test_spiking_likelihood_glm_basic_functionality():
11893 position = position ,
11994 spikes = spikes ,
12095 place_bin_centers = environment .place_bin_centers_ ,
121- place_bin_edges = getattr (environment , ' place_bin_edges_' , None ),
122- edges = getattr (environment , ' edges_' , None ),
96+ place_bin_edges = getattr (environment , " place_bin_edges_" , None ),
97+ edges = getattr (environment , " edges_" , None ),
12398 is_track_interior = environment .is_track_interior_ ,
124- is_track_boundary = getattr (environment , ' is_track_boundary_' , None )
99+ is_track_boundary = getattr (environment , " is_track_boundary_" , None ),
125100 )
126101
127102 # Basic checks on results
@@ -139,28 +114,25 @@ def test_spiking_likelihood_glm_basic_functionality():
139114
140115
141116@pytest .mark .skipif (
142- spiking_likelihood_kde is None ,
143- reason = "spiking_likelihood_kde module not available"
117+ spiking_likelihood_kde is None , reason = "spiking_likelihood_kde module not available"
144118)
145119def test_spiking_likelihood_kde_fit_exists ():
146120 """Test that KDE likelihood fit function exists and is callable."""
147- assert hasattr (spiking_likelihood_kde , ' estimate_place_fields_kde' )
121+ assert hasattr (spiking_likelihood_kde , " estimate_place_fields_kde" )
148122 assert callable (spiking_likelihood_kde .estimate_place_fields_kde )
149123
150124
151125@pytest .mark .skipif (
152- spiking_likelihood_kde is None ,
153- reason = "spiking_likelihood_kde module not available"
126+ spiking_likelihood_kde is None , reason = "spiking_likelihood_kde module not available"
154127)
155128def test_spiking_likelihood_kde_estimate_exists ():
156129 """Test that KDE likelihood estimate function exists and is callable."""
157- assert hasattr (spiking_likelihood_kde , ' estimate_spiking_likelihood_kde' )
130+ assert hasattr (spiking_likelihood_kde , " estimate_spiking_likelihood_kde" )
158131 assert callable (spiking_likelihood_kde .estimate_spiking_likelihood_kde )
159132
160133
161134@pytest .mark .skipif (
162- spiking_likelihood_kde is None ,
163- reason = "spiking_likelihood_kde module not available"
135+ spiking_likelihood_kde is None , reason = "spiking_likelihood_kde module not available"
164136)
165137def test_spiking_likelihood_kde_basic_functionality ():
166138 """Test basic KDE likelihood functionality with synthetic data."""
@@ -171,12 +143,12 @@ def test_spiking_likelihood_kde_basic_functionality():
171143 position = make_simple_position (n_time = 50 , n_dims = 1 )
172144
173145 # Test fit function - try common function names
174- if hasattr (spiking_likelihood_kde , ' estimate_place_fields_kde' ):
146+ if hasattr (spiking_likelihood_kde , " estimate_place_fields_kde" ):
175147 results = spiking_likelihood_kde .estimate_place_fields_kde (
176148 position = position ,
177149 spikes = spikes ,
178150 place_bin_centers = environment .place_bin_centers_ ,
179- is_track_interior = environment .is_track_interior_
151+ is_track_interior = environment .is_track_interior_ ,
180152 )
181153 else :
182154 # Try alternative function names
@@ -195,28 +167,25 @@ def test_spiking_likelihood_kde_basic_functionality():
195167
196168
197169@pytest .mark .skipif (
198- multiunit_likelihood is None ,
199- reason = "multiunit_likelihood module not available"
170+ multiunit_likelihood is None , reason = "multiunit_likelihood module not available"
200171)
201172def test_multiunit_likelihood_fit_exists ():
202173 """Test that multiunit likelihood fit function exists and is callable."""
203- assert hasattr (multiunit_likelihood , ' fit_multiunit_likelihood' )
174+ assert hasattr (multiunit_likelihood , " fit_multiunit_likelihood" )
204175 assert callable (multiunit_likelihood .fit_multiunit_likelihood )
205176
206177
207178@pytest .mark .skipif (
208- multiunit_likelihood is None ,
209- reason = "multiunit_likelihood module not available"
179+ multiunit_likelihood is None , reason = "multiunit_likelihood module not available"
210180)
211181def test_multiunit_likelihood_estimate_exists ():
212182 """Test that multiunit likelihood estimate function exists and is callable."""
213- assert hasattr (multiunit_likelihood , ' estimate_multiunit_likelihood' )
183+ assert hasattr (multiunit_likelihood , " estimate_multiunit_likelihood" )
214184 assert callable (multiunit_likelihood .estimate_multiunit_likelihood )
215185
216186
217187@pytest .mark .skipif (
218- multiunit_likelihood is None ,
219- reason = "multiunit_likelihood module not available"
188+ multiunit_likelihood is None , reason = "multiunit_likelihood module not available"
220189)
221190def test_multiunit_likelihood_basic_functionality ():
222191 """Test basic multiunit likelihood functionality with synthetic data."""
@@ -228,12 +197,16 @@ def test_multiunit_likelihood_basic_functionality():
228197
229198 # Convert multiunit dict to 3D array format
230199 n_electrodes = len (multiunit_data )
231- n_features = list (multiunit_data .values ())[0 ]["marks" ].shape [2 ]
200+ list (multiunit_data .values ())[0 ]["marks" ].shape [2 ]
232201 max_marks = list (multiunit_data .values ())[0 ]["marks" ].shape [1 ]
233202
234203 multiunit_3d = np .full ((50 , max_marks , n_electrodes ), np .nan )
235- for elec_idx , (electrode_id , electrode_data ) in enumerate (multiunit_data .items ()):
236- multiunit_3d [:, :, elec_idx ] = electrode_data ["marks" ][:, :, 0 ] # Use first feature
204+ for elec_idx , (electrode_id , electrode_data ) in enumerate (
205+ multiunit_data .items ()
206+ ):
207+ multiunit_3d [:, :, elec_idx ] = electrode_data ["marks" ][
208+ :, :, 0
209+ ] # Use first feature
237210 no_spike_mask = electrode_data ["no_spike_indicator" ]
238211 multiunit_3d [no_spike_mask , elec_idx ] = np .nan
239212
@@ -243,10 +216,10 @@ def test_multiunit_likelihood_basic_functionality():
243216 multiunits = multiunit_3d ,
244217 place_bin_centers = environment .place_bin_centers_ ,
245218 is_track_interior = environment .is_track_interior_ ,
246- is_track_boundary = getattr (environment , ' is_track_boundary_' , None ),
247- edges = getattr (environment , ' edges_' , None ),
219+ is_track_boundary = getattr (environment , " is_track_boundary_" , None ),
220+ edges = getattr (environment , " edges_" , None ),
248221 mark_std = 24.0 ,
249- position_std = 6.0
222+ position_std = 6.0 ,
250223 )
251224
252225 # Basic checks on results
@@ -262,13 +235,12 @@ def test_multiunit_likelihood_basic_functionality():
262235
263236
264237@pytest .mark .skipif (
265- calcium_likelihood is None ,
266- reason = "calcium_likelihood module not available"
238+ calcium_likelihood is None , reason = "calcium_likelihood module not available"
267239)
268240def test_calcium_likelihood_fit_exists ():
269241 """Test that calcium likelihood fit function exists and is callable."""
270242 # Look for common function names in calcium likelihood module
271- fit_funcs = [name for name in dir (calcium_likelihood ) if ' fit' in name .lower ()]
243+ fit_funcs = [name for name in dir (calcium_likelihood ) if " fit" in name .lower ()]
272244 assert len (fit_funcs ) > 0 , "No fit functions found in calcium_likelihood module"
273245
274246 for func_name in fit_funcs :
@@ -280,14 +252,17 @@ def test_calcium_likelihood_fit_exists():
280252
281253
282254@pytest .mark .skipif (
283- calcium_likelihood is None ,
284- reason = "calcium_likelihood module not available"
255+ calcium_likelihood is None , reason = "calcium_likelihood module not available"
285256)
286257def test_calcium_likelihood_estimate_exists ():
287258 """Test that calcium likelihood estimate function exists and is callable."""
288259 # Look for common function names in calcium likelihood module
289- estimate_funcs = [name for name in dir (calcium_likelihood ) if 'estimate' in name .lower ()]
290- assert len (estimate_funcs ) > 0 , "No estimate functions found in calcium_likelihood module"
260+ estimate_funcs = [
261+ name for name in dir (calcium_likelihood ) if "estimate" in name .lower ()
262+ ]
263+ assert (
264+ len (estimate_funcs ) > 0
265+ ), "No estimate functions found in calcium_likelihood module"
291266
292267
293268# ---------------------- General Likelihood Interface Tests ----------------------
@@ -298,31 +273,39 @@ def test_likelihood_modules_importable():
298273 # This test verifies basic import functionality
299274 try :
300275 from replay_trajectory_classification import likelihoods
276+
301277 assert likelihoods is not None
302278 except ImportError :
303279 pytest .fail ("Could not import likelihoods subpackage" )
304280
305281
306- @pytest .mark .parametrize ("module_name" , [
307- "spiking_likelihood_glm" ,
308- "spiking_likelihood_kde" ,
309- "multiunit_likelihood" ,
310- "calcium_likelihood"
311- ])
282+ @pytest .mark .parametrize (
283+ "module_name" ,
284+ [
285+ "spiking_likelihood_glm" ,
286+ "spiking_likelihood_kde" ,
287+ "multiunit_likelihood" ,
288+ "calcium_likelihood" ,
289+ ],
290+ )
312291def test_likelihood_module_structure (module_name ):
313292 """Test that likelihood modules have expected structure."""
314293 try :
315294 module = __import__ (
316295 f"replay_trajectory_classification.likelihoods.{ module_name } " ,
317- fromlist = [module_name ]
296+ fromlist = [module_name ],
318297 )
319298
320299 # Should have at least some public functions
321- public_attrs = [name for name in dir (module ) if not name .startswith ('_' )]
300+ public_attrs = [name for name in dir (module ) if not name .startswith ("_" )]
322301 assert len (public_attrs ) > 0 , f"No public attributes in { module_name } "
323302
324303 # Should have at least one callable
325- callables = [getattr (module , name ) for name in public_attrs if callable (getattr (module , name ))]
304+ callables = [
305+ getattr (module , name )
306+ for name in public_attrs
307+ if callable (getattr (module , name ))
308+ ]
326309 assert len (callables ) > 0 , f"No callable functions in { module_name } "
327310
328311 except ImportError :
@@ -343,8 +326,8 @@ def test_likelihood_functions_handle_edge_cases():
343326
344327 # Test each available likelihood module
345328 modules_to_test = [
346- (spiking_likelihood_glm , ' fit_spiking_likelihood_glm' ),
347- (spiking_likelihood_kde , ' fit_spiking_likelihood_kde' ),
329+ (spiking_likelihood_glm , " fit_spiking_likelihood_glm" ),
330+ (spiking_likelihood_kde , " fit_spiking_likelihood_kde" ),
348331 ]
349332
350333 for module , func_name in modules_to_test :
@@ -355,7 +338,7 @@ def test_likelihood_functions_handle_edge_cases():
355338 result = func (empty_position , empty_spikes , environment )
356339 # If it returns something, it should be reasonable
357340 if result is not None :
358- assert not (hasattr (result , ' shape' ) and result .shape == (0 ,))
341+ assert not (hasattr (result , " shape" ) and result .shape == (0 ,))
359342 except (ValueError , RuntimeError , ZeroDivisionError ):
360343 # These are acceptable exceptions for edge cases
361344 continue
@@ -370,9 +353,21 @@ def test_likelihood_functions_handle_edge_cases():
370353def test_fit_estimate_consistency ():
371354 """Test that fit and estimate functions have consistent interfaces."""
372355 modules_to_test = [
373- (spiking_likelihood_glm , 'estimate_place_fields' , 'estimate_spiking_likelihood' ),
374- (spiking_likelihood_kde , 'estimate_place_fields_kde' , 'estimate_spiking_likelihood_kde' ),
375- (multiunit_likelihood , 'fit_multiunit_likelihood' , 'estimate_multiunit_likelihood' ),
356+ (
357+ spiking_likelihood_glm ,
358+ "estimate_place_fields" ,
359+ "estimate_spiking_likelihood" ,
360+ ),
361+ (
362+ spiking_likelihood_kde ,
363+ "estimate_place_fields_kde" ,
364+ "estimate_spiking_likelihood_kde" ,
365+ ),
366+ (
367+ multiunit_likelihood ,
368+ "fit_multiunit_likelihood" ,
369+ "estimate_multiunit_likelihood" ,
370+ ),
376371 ]
377372
378373 for module , fit_func_name , estimate_func_name in modules_to_test :
@@ -383,10 +378,12 @@ def test_fit_estimate_consistency():
383378
384379 if has_fit or has_estimate :
385380 # If one exists, both should exist for consistency
386- assert has_fit and has_estimate , f"Module { module .__name__ } should have both fit and estimate functions"
381+ assert (
382+ has_fit and has_estimate
383+ ), f"Module { module .__name__ } should have both fit and estimate functions"
387384
388385 # Both should be callable
389386 fit_func = getattr (module , fit_func_name )
390387 estimate_func = getattr (module , estimate_func_name )
391388 assert callable (fit_func )
392- assert callable (estimate_func )
389+ assert callable (estimate_func )
0 commit comments