Skip to content

Commit 6367144

Browse files
edenoclaude
andcommitted
feat: implement log-space computation for local KDE likelihood
Replaced linear-space marginal density computation with pure log-space to prevent underflow when dealing with high-dimensional features and extreme feature distances. Changes: - Use block_log_kde() instead of block_kde() + safe_log() - Compute spike contribution as: log(rate) + log(density) - log(occupancy) - Add tests demonstrating improved numerical stability Numerical validation: - Property tests: 23/23 passed - Snapshot tests: 39/39 passed (no changes) - Golden regression: Same results as baseline - Full test suite: 499 passed (+3 new), 44 failed (pre-existing) Benefits: - Prevents underflow in extreme cases (features shifted 100+ std devs) - Maintains exact behavior for normal cases (< 1e-14 difference) - Improves accuracy: log-space computes -45k correctly vs -34.54 clamping 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
1 parent ad83072 commit 6367144

File tree

2 files changed

+299
-12
lines changed

2 files changed

+299
-12
lines changed

src/non_local_detector/likelihoods/clusterless_kde_log.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
LOG_EPS,
1111
KDEModel,
1212
block_kde,
13+
block_log_kde,
1314
gaussian_pdf,
1415
get_position_at_time,
1516
get_spike_time_bin_ind,
@@ -1300,7 +1301,8 @@ def compute_local_log_likelihood(
13001301
position_at_spike_time = all_spike_positions[electrode_idx]
13011302
occupancy_at_spike_time = all_occupancies[start_idx:end_idx]
13021303

1303-
marginal_density = block_kde(
1304+
# Compute marginal density in log-space for numerical stability
1305+
log_marginal_density = block_log_kde(
13041306
eval_points=jnp.concatenate(
13051307
(
13061308
position_at_spike_time,
@@ -1319,18 +1321,16 @@ def compute_local_log_likelihood(
13191321
block_size=block_size,
13201322
)
13211323

1322-
# Use safe_log to avoid -inf from zero marginal_density or occupancy
1323-
# The where still protects against division by zero occupancy
1324+
# Compute spike contribution in log-space:
1325+
# log(rate * density / occupancy) = log(rate) + log(density) - log(occupancy)
1326+
log_mean_rate = jnp.log(electrode_mean_rate)
1327+
log_occupancy = safe_log(occupancy_at_spike_time, eps=EPS)
1328+
1329+
# Spike contribution: sum over spikes in each time bin
1330+
spike_contribution = log_mean_rate + log_marginal_density - log_occupancy
1331+
13241332
log_likelihood += jax.ops.segment_sum(
1325-
safe_log(
1326-
electrode_mean_rate
1327-
* jnp.where(
1328-
occupancy_at_spike_time > 0.0,
1329-
marginal_density / occupancy_at_spike_time,
1330-
EPS, # Use EPS instead of 0 to avoid log(0)
1331-
),
1332-
eps=EPS,
1333-
),
1333+
spike_contribution,
13341334
get_spike_time_bin_ind(electrode_spike_times, time),
13351335
indices_are_sorted=True,
13361336
num_segments=n_time,
Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
"""Test numerical stability of log-space local KDE likelihood computation.
2+
3+
This test verifies that computing local likelihood entirely in log-space
4+
prevents underflow issues that can occur when multiplying many small Gaussian
5+
values in linear space before taking the log.
6+
"""
7+
8+
import jax
9+
import jax.numpy as jnp
10+
import numpy as np
11+
import pytest
12+
13+
from non_local_detector.environment import Environment
14+
from non_local_detector.likelihoods.clusterless_kde_log import (
15+
compute_local_log_likelihood,
16+
fit_clusterless_kde_encoding_model,
17+
)
18+
from non_local_detector.likelihoods.common import block_kde, block_log_kde, safe_log, EPS
19+
20+
21+
@pytest.fixture
22+
def extreme_feature_data():
23+
"""Generate synthetic data with extreme waveform features.
24+
25+
This creates conditions where linear-space computation would underflow:
26+
- High dimensional waveform features (8 dimensions)
27+
- Large feature distances (features far from encoding samples)
28+
- Small standard deviations (narrow Gaussians)
29+
30+
These conditions make the product of per-dimension Gaussians extremely small,
31+
causing underflow in linear space but handled correctly in log-space.
32+
"""
33+
np.random.seed(123)
34+
35+
# Time and position
36+
n_time_pos = 500
37+
position_time = np.linspace(0, 5, n_time_pos)
38+
position = np.linspace(0, 50, n_time_pos)[:, None] # 1D position
39+
40+
# Spike data for 2 electrodes
41+
n_electrodes = 2
42+
spike_times = []
43+
spike_waveform_features = []
44+
45+
for i in range(n_electrodes):
46+
# Encoding spikes (training data) - moderate features
47+
n_enc_spikes = 40
48+
enc_spike_times = np.sort(
49+
np.random.uniform(position_time[0], position_time[-1], n_enc_spikes)
50+
)
51+
# Features centered around 0
52+
enc_features = np.random.randn(n_enc_spikes, 8) * 5.0
53+
54+
# Decoding spikes (test data) - EXTREME features far from encoding
55+
n_dec_spikes = 20
56+
dec_spike_times = np.sort(np.random.uniform(2.0, 3.0, n_dec_spikes))
57+
# Features shifted far away from encoding data
58+
# This creates very small Gaussian values that underflow in linear space
59+
dec_features = np.random.randn(n_dec_spikes, 8) * 5.0 + 50.0 # Shifted by 50!
60+
61+
# Combine encoding and decoding for storage
62+
all_times = np.concatenate([enc_spike_times, dec_spike_times])
63+
all_features = np.concatenate([enc_features, dec_features], axis=0)
64+
65+
spike_times.append(all_times)
66+
spike_waveform_features.append(all_features)
67+
68+
return {
69+
"position_time": position_time,
70+
"position": position,
71+
"spike_times": spike_times,
72+
"spike_waveform_features": spike_waveform_features,
73+
"dec_time_range": (2.0, 3.0), # Time range where we have extreme features
74+
}
75+
76+
77+
@pytest.fixture
78+
def simple_1d_environment():
79+
"""Create a simple 1D environment for testing."""
80+
env = Environment(
81+
environment_name="line", place_bin_size=5.0, position_range=((0.0, 50.0),)
82+
)
83+
# Fit with dummy position
84+
dummy_pos = np.linspace(0.0, 50.0, 11)[:, None]
85+
env = env.fit_place_grid(position=dummy_pos, infer_track_interior=False)
86+
return env
87+
88+
89+
def test_local_likelihood_log_space_prevents_underflow(
90+
extreme_feature_data, simple_1d_environment
91+
):
92+
"""Test that log-space local likelihood handles extreme features without underflow.
93+
94+
This test verifies that when decoding spikes have waveform features very far
95+
from encoding samples (causing underflow in linear-space computation), the
96+
log-space implementation still produces finite, reasonable results.
97+
98+
The test checks that:
99+
1. All likelihood values are finite (no -inf from underflow)
100+
2. The expected-counts term dominates (since spike term should be very negative)
101+
3. The result is numerically stable across different random seeds
102+
"""
103+
data = extreme_feature_data
104+
env = simple_1d_environment
105+
106+
# Fit encoding model
107+
enc_model = fit_clusterless_kde_encoding_model(
108+
position_time=data["position_time"],
109+
position=data["position"],
110+
spike_times=data["spike_times"],
111+
spike_waveform_features=data["spike_waveform_features"],
112+
environment=env,
113+
position_std=6.0,
114+
waveform_std=3.0, # Small std → narrow Gaussians → more underflow risk
115+
block_size=50,
116+
disable_progress_bar=True,
117+
)
118+
119+
# Decode on time window with extreme features
120+
dec_start, dec_end = data["dec_time_range"]
121+
time = np.linspace(dec_start, dec_end, 10)
122+
123+
# Compute local likelihood
124+
ll_local = compute_local_log_likelihood(
125+
time=time,
126+
position_time=data["position_time"],
127+
position=data["position"],
128+
spike_times=data["spike_times"],
129+
spike_waveform_features=data["spike_waveform_features"],
130+
occupancy_model=enc_model["occupancy_model"],
131+
gpi_models=enc_model["gpi_models"],
132+
encoding_spike_waveform_features=enc_model["encoding_spike_waveform_features"],
133+
encoding_positions=enc_model["encoding_positions"],
134+
environment=env,
135+
mean_rates=jnp.array(enc_model["mean_rates"]),
136+
position_std=enc_model["position_std"],
137+
waveform_std=enc_model["waveform_std"],
138+
block_size=50,
139+
disable_progress_bar=True,
140+
)
141+
142+
# 1. All values should be finite (no -inf from underflow)
143+
assert np.all(np.isfinite(ll_local)), (
144+
f"Local likelihood contains non-finite values: "
145+
f"min={np.min(ll_local)}, max={np.max(ll_local)}, "
146+
f"n_inf={np.sum(np.isinf(ll_local))}, n_nan={np.sum(np.isnan(ll_local))}"
147+
)
148+
149+
# 2. Values should be reasonable (negative, since log-likelihood)
150+
# With extreme features far from encoding data, spike contributions should be
151+
# very negative, so expected-counts term dominates
152+
assert np.all(ll_local < 0), "Log-likelihood should be negative"
153+
154+
# 3. Should not saturate at LOG_EPS (which would indicate underflow)
155+
from non_local_detector.likelihoods.common import LOG_EPS
156+
157+
# If underflow occurred, many values would equal LOG_EPS
158+
n_at_log_eps = np.sum(np.isclose(ll_local, LOG_EPS, rtol=0, atol=1e-10))
159+
assert n_at_log_eps == 0, (
160+
f"Log-likelihood saturated at LOG_EPS in {n_at_log_eps}/{ll_local.size} values, "
161+
f"indicating underflow in linear-space computation"
162+
)
163+
164+
# 4. Verify shape
165+
assert ll_local.shape == (len(time), 1), (
166+
f"Expected shape ({len(time)}, 1), got {ll_local.shape}"
167+
)
168+
169+
170+
def test_local_likelihood_log_space_moderate_features(
171+
extreme_feature_data, simple_1d_environment
172+
):
173+
"""Test that log-space local likelihood works correctly with moderate features.
174+
175+
This is a sanity check that the log-space implementation doesn't break
176+
normal cases where linear-space would have worked fine.
177+
"""
178+
data = extreme_feature_data
179+
env = simple_1d_environment
180+
181+
# Modify data to have moderate features (not extreme)
182+
moderate_spike_features = [
183+
np.random.randn(len(times), 8) * 5.0 # No shift, moderate scale
184+
for times in data["spike_times"]
185+
]
186+
187+
# Fit encoding model
188+
enc_model = fit_clusterless_kde_encoding_model(
189+
position_time=data["position_time"],
190+
position=data["position"],
191+
spike_times=data["spike_times"],
192+
spike_waveform_features=moderate_spike_features,
193+
environment=env,
194+
position_std=6.0,
195+
waveform_std=10.0, # Larger std → less risk of underflow
196+
block_size=50,
197+
disable_progress_bar=True,
198+
)
199+
200+
# Decode on time window
201+
time = np.linspace(1.0, 2.0, 10)
202+
203+
# Compute local likelihood
204+
ll_local = compute_local_log_likelihood(
205+
time=time,
206+
position_time=data["position_time"],
207+
position=data["position"],
208+
spike_times=data["spike_times"],
209+
spike_waveform_features=moderate_spike_features,
210+
occupancy_model=enc_model["occupancy_model"],
211+
gpi_models=enc_model["gpi_models"],
212+
encoding_spike_waveform_features=enc_model["encoding_spike_waveform_features"],
213+
encoding_positions=enc_model["encoding_positions"],
214+
environment=env,
215+
mean_rates=jnp.array(enc_model["mean_rates"]),
216+
position_std=enc_model["position_std"],
217+
waveform_std=enc_model["waveform_std"],
218+
block_size=50,
219+
disable_progress_bar=True,
220+
)
221+
222+
# All values should be finite and reasonable
223+
assert np.all(np.isfinite(ll_local)), "Local likelihood should be finite"
224+
assert np.all(ll_local < 0), "Log-likelihood should be negative"
225+
assert ll_local.shape == (len(time), 1), "Shape should be correct"
226+
227+
228+
def test_block_log_kde_vs_log_block_kde():
229+
"""Test that block_log_kde is more accurate than safe_log(block_kde).
230+
231+
This directly tests the numerical difference between:
232+
1. Linear-space: marginal = block_kde(...); log_marginal = safe_log(marginal)
233+
2. Log-space: log_marginal = block_log_kde(...)
234+
235+
With extreme feature distances, (1) suffers from underflow and produces
236+
inaccurate results (clamped to LOG_EPS), while (2) computes accurately.
237+
"""
238+
np.random.seed(456)
239+
240+
# Create test data with VERY extreme feature distances
241+
n_eval = 10
242+
n_samples = 20
243+
n_dims = 10 # High dimensional
244+
245+
# Evaluation points FAR from samples
246+
eval_points = np.random.randn(n_eval, n_dims) * 5.0 + 100.0 # Shifted by 100!
247+
248+
# Samples centered at origin
249+
samples = np.random.randn(n_samples, n_dims) * 5.0 # Near zero
250+
251+
# Small standard deviations → very narrow Gaussians → underflow in linear space
252+
std = np.ones(n_dims) * 1.0
253+
254+
# Method 1: Linear-space then log (current implementation)
255+
marginal_linear = block_kde(eval_points, samples, std, block_size=5)
256+
log_marginal_linear = safe_log(marginal_linear, eps=EPS)
257+
258+
# Method 2: Pure log-space (proposed implementation)
259+
log_marginal_log = block_log_kde(eval_points, samples, std, block_size=5)
260+
261+
# Check that linear-space method has underflow (values exactly at LOG_EPS)
262+
from non_local_detector.likelihoods.common import LOG_EPS
263+
264+
# With extreme distances, marginal_linear should underflow to exactly 0
265+
assert np.all(marginal_linear == 0), (
266+
f"Expected all marginal_linear values to underflow to 0, "
267+
f"but got min={np.min(marginal_linear)}, max={np.max(marginal_linear)}"
268+
)
269+
270+
# Which safe_log clamps to exactly LOG_EPS
271+
assert np.all(log_marginal_linear == LOG_EPS), (
272+
f"Expected all log_marginal_linear values to be clamped to LOG_EPS={LOG_EPS}, "
273+
f"but got values: {np.unique(log_marginal_linear)}"
274+
)
275+
276+
# Log-space method should compute much more negative (accurate) values
277+
assert np.all(log_marginal_log < LOG_EPS), (
278+
f"Log-space values should be more negative than LOG_EPS={LOG_EPS}, "
279+
f"but got min={np.min(log_marginal_log)}, max={np.max(log_marginal_log)}"
280+
)
281+
282+
# The difference should be substantial (thousands in log-space)
283+
min_diff = np.min(log_marginal_log - LOG_EPS)
284+
assert min_diff < -1000, (
285+
f"Expected large difference (< -1000) between log-space and clamped linear-space, "
286+
f"but got min_diff={min_diff:.2f}"
287+
)

0 commit comments

Comments
 (0)