Skip to content

Commit 151318c

Browse files
committed
Refactor likelihoods test imports and formatting
Simplifies imports for likelihood modules by removing try/except fallbacks and directly importing modules. Cleans up formatting, improves code readability, and updates string quoting for consistency. No changes to test logic or functionality.
1 parent f50d9d0 commit 151318c

File tree

1 file changed

+84
-87
lines changed

1 file changed

+84
-87
lines changed

replay_trajectory_classification/tests/unit/test_likelihoods.py

Lines changed: 84 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,17 @@
11
# replay_trajectory_classification/tests/unit/test_likelihoods.py
22
from __future__ import annotations
33

4-
from typing import TYPE_CHECKING, Any
5-
64
import numpy as np
75
import pytest
86

7+
import replay_trajectory_classification.likelihoods.calcium_likelihood as calcium_likelihood
8+
import replay_trajectory_classification.likelihoods.multiunit_likelihood as multiunit_likelihood
9+
import replay_trajectory_classification.likelihoods.spiking_likelihood_glm as spiking_likelihood_glm
10+
import replay_trajectory_classification.likelihoods.spiking_likelihood_kde as spiking_likelihood_kde
11+
912
# Test imports for likelihood modules
1013
from replay_trajectory_classification.environments import Environment
1114

12-
if TYPE_CHECKING:
13-
pass
14-
15-
# Try importing likelihood functions with graceful fallbacks
16-
try:
17-
import replay_trajectory_classification.likelihoods.spiking_likelihood_glm as spiking_likelihood_glm
18-
except ImportError:
19-
spiking_likelihood_glm: Any = None
20-
21-
try:
22-
import replay_trajectory_classification.likelihoods.spiking_likelihood_kde as spiking_likelihood_kde
23-
except ImportError:
24-
spiking_likelihood_kde: Any = None
25-
26-
try:
27-
import replay_trajectory_classification.likelihoods.multiunit_likelihood as multiunit_likelihood
28-
except ImportError:
29-
multiunit_likelihood: Any = None
30-
31-
try:
32-
import replay_trajectory_classification.likelihoods.calcium_likelihood as calcium_likelihood
33-
except ImportError:
34-
calcium_likelihood: Any = None
35-
36-
3715
# ---------------------- Helpers ----------------------
3816

3917

@@ -68,7 +46,7 @@ def make_multiunit_data(n_electrodes=3, n_features=4, n_time=20):
6846

6947
for t in range(n_time):
7048
if n_spikes_per_time[t] < max_spikes:
71-
no_spike_indicator[t, n_spikes_per_time[t]:] = True
49+
no_spike_indicator[t, n_spikes_per_time[t] :] = True
7250

7351
data[f"electrode_{elec_id:02d}"] = {
7452
"marks": marks,
@@ -82,28 +60,25 @@ def make_multiunit_data(n_electrodes=3, n_features=4, n_time=20):
8260

8361

8462
@pytest.mark.skipif(
85-
spiking_likelihood_glm is None,
86-
reason="spiking_likelihood_glm module not available"
63+
spiking_likelihood_glm is None, reason="spiking_likelihood_glm module not available"
8764
)
8865
def test_spiking_likelihood_glm_fit_exists():
8966
"""Test that GLM likelihood fit function exists and is callable."""
90-
assert hasattr(spiking_likelihood_glm, 'estimate_place_fields')
67+
assert hasattr(spiking_likelihood_glm, "estimate_place_fields")
9168
assert callable(spiking_likelihood_glm.estimate_place_fields)
9269

9370

9471
@pytest.mark.skipif(
95-
spiking_likelihood_glm is None,
96-
reason="spiking_likelihood_glm module not available"
72+
spiking_likelihood_glm is None, reason="spiking_likelihood_glm module not available"
9773
)
9874
def test_spiking_likelihood_glm_estimate_exists():
9975
"""Test that GLM likelihood estimate function exists and is callable."""
100-
assert hasattr(spiking_likelihood_glm, 'estimate_spiking_likelihood')
76+
assert hasattr(spiking_likelihood_glm, "estimate_spiking_likelihood")
10177
assert callable(spiking_likelihood_glm.estimate_spiking_likelihood)
10278

10379

10480
@pytest.mark.skipif(
105-
spiking_likelihood_glm is None,
106-
reason="spiking_likelihood_glm module not available"
81+
spiking_likelihood_glm is None, reason="spiking_likelihood_glm module not available"
10782
)
10883
def test_spiking_likelihood_glm_basic_functionality():
10984
"""Test basic GLM likelihood functionality with synthetic data."""
@@ -118,10 +93,10 @@ def test_spiking_likelihood_glm_basic_functionality():
11893
position=position,
11994
spikes=spikes,
12095
place_bin_centers=environment.place_bin_centers_,
121-
place_bin_edges=getattr(environment, 'place_bin_edges_', None),
122-
edges=getattr(environment, 'edges_', None),
96+
place_bin_edges=getattr(environment, "place_bin_edges_", None),
97+
edges=getattr(environment, "edges_", None),
12398
is_track_interior=environment.is_track_interior_,
124-
is_track_boundary=getattr(environment, 'is_track_boundary_', None)
99+
is_track_boundary=getattr(environment, "is_track_boundary_", None),
125100
)
126101

127102
# Basic checks on results
@@ -139,28 +114,25 @@ def test_spiking_likelihood_glm_basic_functionality():
139114

140115

141116
@pytest.mark.skipif(
142-
spiking_likelihood_kde is None,
143-
reason="spiking_likelihood_kde module not available"
117+
spiking_likelihood_kde is None, reason="spiking_likelihood_kde module not available"
144118
)
145119
def test_spiking_likelihood_kde_fit_exists():
146120
"""Test that KDE likelihood fit function exists and is callable."""
147-
assert hasattr(spiking_likelihood_kde, 'estimate_place_fields_kde')
121+
assert hasattr(spiking_likelihood_kde, "estimate_place_fields_kde")
148122
assert callable(spiking_likelihood_kde.estimate_place_fields_kde)
149123

150124

151125
@pytest.mark.skipif(
152-
spiking_likelihood_kde is None,
153-
reason="spiking_likelihood_kde module not available"
126+
spiking_likelihood_kde is None, reason="spiking_likelihood_kde module not available"
154127
)
155128
def test_spiking_likelihood_kde_estimate_exists():
156129
"""Test that KDE likelihood estimate function exists and is callable."""
157-
assert hasattr(spiking_likelihood_kde, 'estimate_spiking_likelihood_kde')
130+
assert hasattr(spiking_likelihood_kde, "estimate_spiking_likelihood_kde")
158131
assert callable(spiking_likelihood_kde.estimate_spiking_likelihood_kde)
159132

160133

161134
@pytest.mark.skipif(
162-
spiking_likelihood_kde is None,
163-
reason="spiking_likelihood_kde module not available"
135+
spiking_likelihood_kde is None, reason="spiking_likelihood_kde module not available"
164136
)
165137
def test_spiking_likelihood_kde_basic_functionality():
166138
"""Test basic KDE likelihood functionality with synthetic data."""
@@ -171,12 +143,12 @@ def test_spiking_likelihood_kde_basic_functionality():
171143
position = make_simple_position(n_time=50, n_dims=1)
172144

173145
# Test fit function - try common function names
174-
if hasattr(spiking_likelihood_kde, 'estimate_place_fields_kde'):
146+
if hasattr(spiking_likelihood_kde, "estimate_place_fields_kde"):
175147
results = spiking_likelihood_kde.estimate_place_fields_kde(
176148
position=position,
177149
spikes=spikes,
178150
place_bin_centers=environment.place_bin_centers_,
179-
is_track_interior=environment.is_track_interior_
151+
is_track_interior=environment.is_track_interior_,
180152
)
181153
else:
182154
# Try alternative function names
@@ -195,28 +167,25 @@ def test_spiking_likelihood_kde_basic_functionality():
195167

196168

197169
@pytest.mark.skipif(
198-
multiunit_likelihood is None,
199-
reason="multiunit_likelihood module not available"
170+
multiunit_likelihood is None, reason="multiunit_likelihood module not available"
200171
)
201172
def test_multiunit_likelihood_fit_exists():
202173
"""Test that multiunit likelihood fit function exists and is callable."""
203-
assert hasattr(multiunit_likelihood, 'fit_multiunit_likelihood')
174+
assert hasattr(multiunit_likelihood, "fit_multiunit_likelihood")
204175
assert callable(multiunit_likelihood.fit_multiunit_likelihood)
205176

206177

207178
@pytest.mark.skipif(
208-
multiunit_likelihood is None,
209-
reason="multiunit_likelihood module not available"
179+
multiunit_likelihood is None, reason="multiunit_likelihood module not available"
210180
)
211181
def test_multiunit_likelihood_estimate_exists():
212182
"""Test that multiunit likelihood estimate function exists and is callable."""
213-
assert hasattr(multiunit_likelihood, 'estimate_multiunit_likelihood')
183+
assert hasattr(multiunit_likelihood, "estimate_multiunit_likelihood")
214184
assert callable(multiunit_likelihood.estimate_multiunit_likelihood)
215185

216186

217187
@pytest.mark.skipif(
218-
multiunit_likelihood is None,
219-
reason="multiunit_likelihood module not available"
188+
multiunit_likelihood is None, reason="multiunit_likelihood module not available"
220189
)
221190
def test_multiunit_likelihood_basic_functionality():
222191
"""Test basic multiunit likelihood functionality with synthetic data."""
@@ -228,12 +197,16 @@ def test_multiunit_likelihood_basic_functionality():
228197

229198
# Convert multiunit dict to 3D array format
230199
n_electrodes = len(multiunit_data)
231-
n_features = list(multiunit_data.values())[0]["marks"].shape[2]
200+
list(multiunit_data.values())[0]["marks"].shape[2]
232201
max_marks = list(multiunit_data.values())[0]["marks"].shape[1]
233202

234203
multiunit_3d = np.full((50, max_marks, n_electrodes), np.nan)
235-
for elec_idx, (electrode_id, electrode_data) in enumerate(multiunit_data.items()):
236-
multiunit_3d[:, :, elec_idx] = electrode_data["marks"][:, :, 0] # Use first feature
204+
for elec_idx, (electrode_id, electrode_data) in enumerate(
205+
multiunit_data.items()
206+
):
207+
multiunit_3d[:, :, elec_idx] = electrode_data["marks"][
208+
:, :, 0
209+
] # Use first feature
237210
no_spike_mask = electrode_data["no_spike_indicator"]
238211
multiunit_3d[no_spike_mask, elec_idx] = np.nan
239212

@@ -243,10 +216,10 @@ def test_multiunit_likelihood_basic_functionality():
243216
multiunits=multiunit_3d,
244217
place_bin_centers=environment.place_bin_centers_,
245218
is_track_interior=environment.is_track_interior_,
246-
is_track_boundary=getattr(environment, 'is_track_boundary_', None),
247-
edges=getattr(environment, 'edges_', None),
219+
is_track_boundary=getattr(environment, "is_track_boundary_", None),
220+
edges=getattr(environment, "edges_", None),
248221
mark_std=24.0,
249-
position_std=6.0
222+
position_std=6.0,
250223
)
251224

252225
# Basic checks on results
@@ -262,13 +235,12 @@ def test_multiunit_likelihood_basic_functionality():
262235

263236

264237
@pytest.mark.skipif(
265-
calcium_likelihood is None,
266-
reason="calcium_likelihood module not available"
238+
calcium_likelihood is None, reason="calcium_likelihood module not available"
267239
)
268240
def test_calcium_likelihood_fit_exists():
269241
"""Test that calcium likelihood fit function exists and is callable."""
270242
# Look for common function names in calcium likelihood module
271-
fit_funcs = [name for name in dir(calcium_likelihood) if 'fit' in name.lower()]
243+
fit_funcs = [name for name in dir(calcium_likelihood) if "fit" in name.lower()]
272244
assert len(fit_funcs) > 0, "No fit functions found in calcium_likelihood module"
273245

274246
for func_name in fit_funcs:
@@ -280,14 +252,17 @@ def test_calcium_likelihood_fit_exists():
280252

281253

282254
@pytest.mark.skipif(
283-
calcium_likelihood is None,
284-
reason="calcium_likelihood module not available"
255+
calcium_likelihood is None, reason="calcium_likelihood module not available"
285256
)
286257
def test_calcium_likelihood_estimate_exists():
287258
"""Test that calcium likelihood estimate function exists and is callable."""
288259
# Look for common function names in calcium likelihood module
289-
estimate_funcs = [name for name in dir(calcium_likelihood) if 'estimate' in name.lower()]
290-
assert len(estimate_funcs) > 0, "No estimate functions found in calcium_likelihood module"
260+
estimate_funcs = [
261+
name for name in dir(calcium_likelihood) if "estimate" in name.lower()
262+
]
263+
assert (
264+
len(estimate_funcs) > 0
265+
), "No estimate functions found in calcium_likelihood module"
291266

292267

293268
# ---------------------- General Likelihood Interface Tests ----------------------
@@ -298,31 +273,39 @@ def test_likelihood_modules_importable():
298273
# This test verifies basic import functionality
299274
try:
300275
from replay_trajectory_classification import likelihoods
276+
301277
assert likelihoods is not None
302278
except ImportError:
303279
pytest.fail("Could not import likelihoods subpackage")
304280

305281

306-
@pytest.mark.parametrize("module_name", [
307-
"spiking_likelihood_glm",
308-
"spiking_likelihood_kde",
309-
"multiunit_likelihood",
310-
"calcium_likelihood"
311-
])
282+
@pytest.mark.parametrize(
283+
"module_name",
284+
[
285+
"spiking_likelihood_glm",
286+
"spiking_likelihood_kde",
287+
"multiunit_likelihood",
288+
"calcium_likelihood",
289+
],
290+
)
312291
def test_likelihood_module_structure(module_name):
313292
"""Test that likelihood modules have expected structure."""
314293
try:
315294
module = __import__(
316295
f"replay_trajectory_classification.likelihoods.{module_name}",
317-
fromlist=[module_name]
296+
fromlist=[module_name],
318297
)
319298

320299
# Should have at least some public functions
321-
public_attrs = [name for name in dir(module) if not name.startswith('_')]
300+
public_attrs = [name for name in dir(module) if not name.startswith("_")]
322301
assert len(public_attrs) > 0, f"No public attributes in {module_name}"
323302

324303
# Should have at least one callable
325-
callables = [getattr(module, name) for name in public_attrs if callable(getattr(module, name))]
304+
callables = [
305+
getattr(module, name)
306+
for name in public_attrs
307+
if callable(getattr(module, name))
308+
]
326309
assert len(callables) > 0, f"No callable functions in {module_name}"
327310

328311
except ImportError:
@@ -343,8 +326,8 @@ def test_likelihood_functions_handle_edge_cases():
343326

344327
# Test each available likelihood module
345328
modules_to_test = [
346-
(spiking_likelihood_glm, 'fit_spiking_likelihood_glm'),
347-
(spiking_likelihood_kde, 'fit_spiking_likelihood_kde'),
329+
(spiking_likelihood_glm, "fit_spiking_likelihood_glm"),
330+
(spiking_likelihood_kde, "fit_spiking_likelihood_kde"),
348331
]
349332

350333
for module, func_name in modules_to_test:
@@ -355,7 +338,7 @@ def test_likelihood_functions_handle_edge_cases():
355338
result = func(empty_position, empty_spikes, environment)
356339
# If it returns something, it should be reasonable
357340
if result is not None:
358-
assert not (hasattr(result, 'shape') and result.shape == (0,))
341+
assert not (hasattr(result, "shape") and result.shape == (0,))
359342
except (ValueError, RuntimeError, ZeroDivisionError):
360343
# These are acceptable exceptions for edge cases
361344
continue
@@ -370,9 +353,21 @@ def test_likelihood_functions_handle_edge_cases():
370353
def test_fit_estimate_consistency():
371354
"""Test that fit and estimate functions have consistent interfaces."""
372355
modules_to_test = [
373-
(spiking_likelihood_glm, 'estimate_place_fields', 'estimate_spiking_likelihood'),
374-
(spiking_likelihood_kde, 'estimate_place_fields_kde', 'estimate_spiking_likelihood_kde'),
375-
(multiunit_likelihood, 'fit_multiunit_likelihood', 'estimate_multiunit_likelihood'),
356+
(
357+
spiking_likelihood_glm,
358+
"estimate_place_fields",
359+
"estimate_spiking_likelihood",
360+
),
361+
(
362+
spiking_likelihood_kde,
363+
"estimate_place_fields_kde",
364+
"estimate_spiking_likelihood_kde",
365+
),
366+
(
367+
multiunit_likelihood,
368+
"fit_multiunit_likelihood",
369+
"estimate_multiunit_likelihood",
370+
),
376371
]
377372

378373
for module, fit_func_name, estimate_func_name in modules_to_test:
@@ -383,10 +378,12 @@ def test_fit_estimate_consistency():
383378

384379
if has_fit or has_estimate:
385380
# If one exists, both should exist for consistency
386-
assert has_fit and has_estimate, f"Module {module.__name__} should have both fit and estimate functions"
381+
assert (
382+
has_fit and has_estimate
383+
), f"Module {module.__name__} should have both fit and estimate functions"
387384

388385
# Both should be callable
389386
fit_func = getattr(module, fit_func_name)
390387
estimate_func = getattr(module, estimate_func_name)
391388
assert callable(fit_func)
392-
assert callable(estimate_func)
389+
assert callable(estimate_func)

0 commit comments

Comments
 (0)