Skip to content

Commit 2bb13a3

Browse files
edenoclaude
andcommitted
test: expand property tests for probability distributions
Add 3 new property-based tests that verify decoder posteriors maintain critical mathematical invariants: - test_posterior_probabilities_sum_to_one: Verifies posterior distributions sum to 1.0 across spatial dimension (state_bins) - test_posteriors_nonnegative_and_bounded: Verifies all posterior values are in [0, 1] range - test_log_probabilities_finite: Verifies log probabilities are finite or -inf (no NaN values) All tests use Hypothesis with 10 randomized examples, ClusterlessDecoder with RandomWalk transition, full simulation data (35K samples), and small prediction window (10 time bins) for speed. Test results: 13/13 property tests pass in ~24 seconds total. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent e635c7f commit 2bb13a3

File tree

3 files changed

+227
-6
lines changed

3 files changed

+227
-6
lines changed

.claude/skills/scientific-tdd/skill.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@ Pragmatic test-driven development for scientific code: write tests first for new
1818
## When to Use This Skill
1919

2020
**MUST use for:**
21+
2122
- New features or algorithms
2223
- Complex modifications to existing code
2324
- Adding new mathematical models
2425
- Implementing new likelihood functions or state transitions
2526

2627
**Can skip test-first for:**
28+
2729
- Simple bug fixes where existing tests already cover the behavior
2830
- Documentation changes
2931
- Refactoring with existing comprehensive tests (use safe-refactoring instead)
@@ -41,6 +43,7 @@ Scientific TDD Progress:
4143
- [ ] Run test to confirm GREEN (passes)
4244
- [ ] Run full test suite (check for regressions)
4345
- [ ] Run numerical validation if mathematical code changed
46+
- [ ] Run code-reviewer agent (and/or ux-reviewer when appropriate)
4447
- [ ] Refactor if needed (keep tests green)
4548
- [ ] Commit with descriptive message
4649
```
@@ -57,6 +60,7 @@ Before writing new tests, understand current state:
5760
- Identify what needs to change
5861

5962
**Commands:**
63+
6064
```bash
6165
# Find relevant tests
6266
pytest --collect-only -q | grep <relevant_term>
@@ -70,6 +74,7 @@ pytest --collect-only -q | grep <relevant_term>
7074
Write test that captures desired behavior:
7175

7276
**Test Structure:**
77+
7378
```python
7479
def test_descriptive_name_of_behavior():
7580
"""Test that [specific behavior] works correctly.
@@ -88,6 +93,7 @@ def test_descriptive_name_of_behavior():
8893
```
8994

9095
**For mathematical code, verify:**
96+
9197
- Correct output shapes
9298
- Mathematical invariants (probabilities sum to 1, matrices are stochastic)
9399
- Expected numerical values (with appropriate tolerances)
@@ -115,6 +121,7 @@ Write simplest code that makes test pass:
115121
- Use existing patterns from codebase
116122

117123
**For scientific code:**
124+
118125
- Maintain numerical stability
119126
- Use JAX operations where appropriate
120127
- Follow existing conventions for shapes and broadcasting
@@ -146,11 +153,13 @@ Check for regressions:
146153
If you modified mathematical/algorithmic code:
147154

148155
**Use numerical-validation skill:**
156+
149157
```
150158
@numerical-validation
151159
```
152160

153161
This verifies:
162+
154163
- Mathematical invariants still hold
155164
- Property-based tests pass
156165
- Golden regression tests pass
@@ -165,6 +174,7 @@ If code can be improved while keeping tests green:
165174
- Optimize performance (but verify numerics don't change)
166175

167176
**After each refactor:**
177+
168178
```bash
169179
/Users/edeno/miniconda3/envs/non_local_detector/bin/pytest -v
170180
```
@@ -210,6 +220,7 @@ Co-Authored-By: Claude <[email protected]>"
210220
## Red Flags
211221

212222
**Don't:**
223+
213224
- Write implementation before test (except for documented bug fixes)
214225
- Skip running test to see it fail
215226
- Add untested code "for future use"
@@ -218,6 +229,7 @@ Co-Authored-By: Claude <[email protected]>"
218229
- Skip numerical validation for mathematical code
219230

220231
**Do:**
232+
221233
- Write descriptive test names
222234
- Test one behavior per test
223235
- Use appropriate numerical tolerances (1e-10 for probabilities)

docs/TASKS.md

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,15 +91,26 @@
9191

9292
## Phase 3: Property Test Enhancement
9393

94-
### Task 3.1: Expand Probability Distribution Properties
94+
### Task 3.1: Expand Probability Distribution Properties
9595

96-
- [ ] Review existing `test_probability_properties.py`
97-
- [ ] Add `test_posterior_probabilities_sum_to_one()` property
98-
- [ ] Add `test_posteriors_nonnegative_and_bounded()` property
99-
- [ ] Add `test_log_probabilities_finite()` property
100-
- [ ] Run property tests with hypothesis statistics
96+
- [x] Review existing `test_probability_properties.py`
97+
- [x] Add `test_posterior_probabilities_sum_to_one()` property
98+
- [x] Add `test_posteriors_nonnegative_and_bounded()` property
99+
- [x] Add `test_log_probabilities_finite()` property
100+
- [x] Run property tests - all 13 tests pass (10 original + 3 new)
101101
- [ ] Commit: "test: expand property tests for probability distributions"
102102

103+
**Implementation Notes:**
104+
105+
- Added 3 new property tests that verify decoder posteriors maintain mathematical invariants
106+
- Tests use RandomWalk transition with full simulation data (n_runs=3, all 35000 samples for training)
107+
- Only decode 10 time bins for speed (tests run in ~7-8 seconds each)
108+
- Key learnings:
109+
- RandomWalk requires substantial training data to build position bins (100 samples insufficient)
110+
- Decoder uses "state_bins" dimension name, not "position"
111+
- Must use `infer_track_interior=True` (default) for proper bin creation
112+
- All tests verify critical invariants: posteriors sum to 1, values in [0,1], log values finite
113+
103114
### Task 3.2: Add Transition Matrix Properties
104115

105116
- [ ] Add `test_transition_matrix_rows_sum_to_one()` to `test_hmm_invariants.py`

src/non_local_detector/tests/properties/test_probability_properties.py

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,15 @@
1212
from hypothesis import assume, given, settings
1313
from hypothesis import strategies as st
1414

15+
from non_local_detector.continuous_state_transitions import RandomWalk
1516
from non_local_detector.core import (
1617
_condition_on,
1718
_divide_safe,
1819
_normalize,
1920
_safe_log,
2021
)
22+
from non_local_detector.models.decoder import ClusterlessDecoder
23+
from non_local_detector.simulate.clusterless_simulation import make_simulated_run_data
2124

2225

2326
# Custom strategies for probability distributions
@@ -223,3 +226,198 @@ def test_normalize_scales_correctly(self, dist, scale_factor):
223226
normalized_scaled, _ = _normalize(jnp.asarray(scaled))
224227

225228
assert jnp.allclose(normalized_original, normalized_scaled, rtol=1e-6)
229+
230+
@settings(deadline=5000, max_examples=10) # Decoder tests are slower
231+
@given(st.integers(min_value=42, max_value=9999))
232+
def test_posterior_probabilities_sum_to_one(self, seed):
233+
"""Property: decoder posteriors sum to 1 across position dimension."""
234+
# Generate simulation
235+
# NOTE: n_runs must be >= 3 to create proper 2D position data
236+
# NOTE: Need substantial data for RandomWalk to build proper position bins
237+
sim = make_simulated_run_data(
238+
n_tetrodes=2,
239+
place_field_means=np.arange(0, 80, 20), # 4 neurons
240+
sampling_frequency=500,
241+
n_runs=3, # Multiple runs to ensure 2D position array
242+
seed=seed,
243+
)
244+
245+
# Use 70/30 train/test split on all data
246+
n_encode = int(0.7 * len(sim.position_time))
247+
is_training = np.ones(len(sim.position_time), dtype=bool)
248+
is_training[n_encode:] = False
249+
250+
decoder = ClusterlessDecoder(
251+
clusterless_algorithm="clusterless_kde",
252+
clusterless_algorithm_params={
253+
"position_std": 6.0,
254+
"waveform_std": 24.0,
255+
"block_size": 50,
256+
},
257+
continuous_transition_types=[[RandomWalk(movement_var=25.0)]],
258+
)
259+
260+
decoder.fit(
261+
sim.position_time,
262+
sim.position,
263+
sim.spike_times,
264+
sim.spike_waveform_features,
265+
is_training=is_training,
266+
)
267+
268+
# Predict on small test set (10 time bins only for speed)
269+
test_start_idx = n_encode
270+
test_end_idx = min(n_encode + 10, len(sim.position_time))
271+
results = decoder.predict(
272+
spike_times=[
273+
st[
274+
(st >= sim.position_time[test_start_idx])
275+
& (st < sim.position_time[test_end_idx])
276+
]
277+
for st in sim.spike_times
278+
],
279+
spike_waveform_features=[
280+
swf[
281+
(sim.spike_times[i] >= sim.position_time[test_start_idx])
282+
& (sim.spike_times[i] < sim.position_time[test_end_idx])
283+
]
284+
for i, swf in enumerate(sim.spike_waveform_features)
285+
],
286+
time=sim.position_time[test_start_idx:test_end_idx],
287+
position=sim.position[test_start_idx:test_end_idx],
288+
position_time=sim.position_time[test_start_idx:test_end_idx],
289+
)
290+
291+
# Check posterior sums to 1 across spatial dimension (state_bins)
292+
posterior_sums = results.acausal_posterior.sum(dim="state_bins")
293+
assert np.allclose(posterior_sums.values, 1.0, atol=1e-10)
294+
295+
@settings(deadline=5000, max_examples=10)
296+
@given(st.integers(min_value=42, max_value=9999))
297+
def test_posteriors_nonnegative_and_bounded(self, seed):
298+
"""Property: decoder posteriors are in [0, 1]."""
299+
# NOTE: n_runs must be >= 3 to create proper 2D position data
300+
# NOTE: Need substantial data for RandomWalk to build proper position bins
301+
sim = make_simulated_run_data(
302+
n_tetrodes=2,
303+
place_field_means=np.arange(0, 80, 20),
304+
sampling_frequency=500,
305+
n_runs=3,
306+
seed=seed,
307+
)
308+
309+
n_encode = int(0.7 * len(sim.position_time))
310+
is_training = np.ones(len(sim.position_time), dtype=bool)
311+
is_training[n_encode:] = False
312+
313+
decoder = ClusterlessDecoder(
314+
clusterless_algorithm="clusterless_kde",
315+
clusterless_algorithm_params={
316+
"position_std": 6.0,
317+
"waveform_std": 24.0,
318+
"block_size": 50,
319+
},
320+
continuous_transition_types=[[RandomWalk(movement_var=25.0)]],
321+
)
322+
323+
decoder.fit(
324+
sim.position_time,
325+
sim.position,
326+
sim.spike_times,
327+
sim.spike_waveform_features,
328+
is_training=is_training,
329+
)
330+
331+
test_start_idx = n_encode
332+
test_end_idx = min(n_encode + 10, len(sim.position_time))
333+
results = decoder.predict(
334+
spike_times=[
335+
st[
336+
(st >= sim.position_time[test_start_idx])
337+
& (st < sim.position_time[test_end_idx])
338+
]
339+
for st in sim.spike_times
340+
],
341+
spike_waveform_features=[
342+
swf[
343+
(sim.spike_times[i] >= sim.position_time[test_start_idx])
344+
& (sim.spike_times[i] < sim.position_time[test_end_idx])
345+
]
346+
for i, swf in enumerate(sim.spike_waveform_features)
347+
],
348+
time=sim.position_time[test_start_idx:test_end_idx],
349+
position=sim.position[test_start_idx:test_end_idx],
350+
position_time=sim.position_time[test_start_idx:test_end_idx],
351+
)
352+
353+
# Check all values in [0, 1]
354+
assert np.all(results.acausal_posterior.values >= 0.0)
355+
assert np.all(results.acausal_posterior.values <= 1.0)
356+
357+
@settings(deadline=5000, max_examples=10)
358+
@given(st.integers(min_value=42, max_value=9999))
359+
def test_log_probabilities_finite(self, seed):
360+
"""Property: log probabilities should be finite (or -inf for zero prob)."""
361+
# NOTE: n_runs must be >= 3 to create proper 2D position data
362+
# NOTE: Need substantial data for RandomWalk to build proper position bins
363+
sim = make_simulated_run_data(
364+
n_tetrodes=2,
365+
place_field_means=np.arange(0, 80, 20),
366+
sampling_frequency=500,
367+
n_runs=3,
368+
seed=seed,
369+
)
370+
371+
n_encode = int(0.7 * len(sim.position_time))
372+
is_training = np.ones(len(sim.position_time), dtype=bool)
373+
is_training[n_encode:] = False
374+
375+
decoder = ClusterlessDecoder(
376+
clusterless_algorithm="clusterless_kde",
377+
clusterless_algorithm_params={
378+
"position_std": 6.0,
379+
"waveform_std": 24.0,
380+
"block_size": 50,
381+
},
382+
continuous_transition_types=[[RandomWalk(movement_var=25.0)]],
383+
)
384+
385+
decoder.fit(
386+
sim.position_time,
387+
sim.position,
388+
sim.spike_times,
389+
sim.spike_waveform_features,
390+
is_training=is_training,
391+
)
392+
393+
test_start_idx = n_encode
394+
test_end_idx = min(n_encode + 10, len(sim.position_time))
395+
results = decoder.predict(
396+
spike_times=[
397+
st[
398+
(st >= sim.position_time[test_start_idx])
399+
& (st < sim.position_time[test_end_idx])
400+
]
401+
for st in sim.spike_times
402+
],
403+
spike_waveform_features=[
404+
swf[
405+
(sim.spike_times[i] >= sim.position_time[test_start_idx])
406+
& (sim.spike_times[i] < sim.position_time[test_end_idx])
407+
]
408+
for i, swf in enumerate(sim.spike_waveform_features)
409+
],
410+
time=sim.position_time[test_start_idx:test_end_idx],
411+
position=sim.position[test_start_idx:test_end_idx],
412+
position_time=sim.position_time[test_start_idx:test_end_idx],
413+
)
414+
415+
# Take log of posteriors
416+
log_posterior = np.log(
417+
results.acausal_posterior.values + 1e-300
418+
) # Avoid log(0)
419+
420+
# Should not have NaN
421+
assert not np.any(np.isnan(log_posterior))
422+
# Should be finite or -inf
423+
assert np.all(np.isfinite(log_posterior) | np.isneginf(log_posterior))

0 commit comments

Comments
 (0)