Skip to content

Commit f5c9735

Browse files
committed
Expand and parametrize encoding tiling tests
Refactors and extends tests for encoding and position tiling by parametrizing tile sizes, adding edge case tests, and improving assertion diagnostics. Also updates comments in clusterless_kde_log.py to clarify JAX compilation context.
1 parent d6134eb commit f5c9735

File tree

2 files changed

+82
-15
lines changed

2 files changed

+82
-15
lines changed

src/non_local_detector/likelihoods/clusterless_kde_log.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -615,9 +615,8 @@ def block_estimate_log_joint_mark_intensity(
615615
if n_decoding_spikes == 0:
616616
return jnp.full((0, n_position_bins), LOG_EPS)
617617

618-
# Use JIT-compiled update with buffer donation for memory efficiency
619-
# Donate the accumulator buffer (arg 0) so it can be reused in-place
620-
@jax.jit
618+
# Use dynamic_update_slice to build output
619+
# Note: JAX will JIT-compile this in the calling context
621620
def _update_block(out_array, block_result, start_idx):
622621
return jax.lax.dynamic_update_slice(out_array, block_result, (start_idx, 0))
623622

src/non_local_detector/tests/likelihoods/test_enc_tile_size.py

Lines changed: 80 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,25 @@ def test_enc_tile_size_equivalence(enc_tile_size):
8383
assert np.all(np.isfinite(result_with_enc_tiling))
8484

8585

86-
def test_enc_tile_size_with_pos_tile_size():
87-
"""Test combined encoding and position tiling."""
86+
@pytest.mark.unit
87+
@pytest.mark.parametrize(
88+
"enc_tile_size,pos_tile_size",
89+
[
90+
(30, 15), # Both tiling
91+
(25, 40), # pos_tile_size > n_pos (no position tiling)
92+
(150, 20), # enc_tile_size > n_enc (no encoding tiling)
93+
],
94+
)
95+
def test_enc_tile_size_with_pos_tile_size(enc_tile_size, pos_tile_size):
96+
"""Test combined encoding and position tiling.
97+
98+
Parameters
99+
----------
100+
enc_tile_size : int
101+
Encoding chunk size
102+
pos_tile_size : int
103+
Position chunk size
104+
"""
88105
n_enc_spikes = 100
89106
n_dec_spikes = 15
90107
n_pos_bins = 40
@@ -125,20 +142,71 @@ def test_enc_tile_size_with_pos_tile_size():
125142
mean_rate,
126143
log_position_distance,
127144
use_gemm=True,
128-
pos_tile_size=15, # Tile positions
129-
enc_tile_size=30, # Tile encoding spikes
145+
pos_tile_size=pos_tile_size,
146+
enc_tile_size=enc_tile_size,
130147
)
131148

132149
# Should match
150+
max_diff = np.max(np.abs(result_baseline - result_both_tiling))
133151
assert np.allclose(
134-
result_baseline, result_both_tiling, rtol=1e-5, atol=1e-8
135-
), f"Max diff: {np.max(np.abs(result_baseline - result_both_tiling))}"
152+
result_baseline, result_both_tiling, rtol=1e-5, atol=1e-7
153+
), f"enc_tile_size={enc_tile_size}, pos_tile_size={pos_tile_size}: Max diff = {max_diff}"
136154

137-
print(f"✓ Combined enc_tile_size=30 + pos_tile_size=15 matches baseline")
138-
print(f" Max diff: {np.max(np.abs(result_baseline - result_both_tiling)):.2e}")
139155

156+
@pytest.mark.unit
157+
def test_enc_tile_size_edge_cases():
158+
"""Test edge cases for encoding tiling."""
159+
n_enc_spikes = 10
160+
n_dec_spikes = 5
161+
n_pos_bins = 8
162+
n_features = 2
163+
164+
np.random.seed(456)
165+
dec_features = jnp.array(np.random.randn(n_dec_spikes, n_features) * 5)
166+
enc_features = jnp.array(np.random.randn(n_enc_spikes, n_features) * 5)
167+
waveform_stds = jnp.array([2.0] * n_features)
168+
occupancy = jnp.ones(n_pos_bins) * 0.05
169+
mean_rate = 2.0
170+
171+
enc_positions = jnp.array(np.random.uniform(0, 50, (n_enc_spikes, 1)))
172+
interior_bins = jnp.array(np.linspace(0, 50, n_pos_bins))[:, None]
173+
position_std = jnp.array([3.0])
174+
log_position_distance = log_kde_distance(interior_bins, enc_positions, position_std)
175+
176+
# Baseline
177+
result_baseline = estimate_log_joint_mark_intensity(
178+
dec_features,
179+
enc_features,
180+
waveform_stds,
181+
occupancy,
182+
mean_rate,
183+
log_position_distance,
184+
use_gemm=True,
185+
enc_tile_size=None,
186+
)
187+
188+
# Test: enc_tile_size = 1 (smallest possible)
189+
result_tile1 = estimate_log_joint_mark_intensity(
190+
dec_features,
191+
enc_features,
192+
waveform_stds,
193+
occupancy,
194+
mean_rate,
195+
log_position_distance,
196+
use_gemm=True,
197+
enc_tile_size=1,
198+
)
199+
assert np.allclose(result_baseline, result_tile1, rtol=1e-5, atol=1e-7)
140200

141-
if __name__ == "__main__":
142-
test_enc_tile_size_equivalence()
143-
test_enc_tile_size_with_pos_tile_size()
144-
print("\n✅ All enc_tile_size tests passed!")
201+
# Test: enc_tile_size = n_enc (no chunking)
202+
result_tile_full = estimate_log_joint_mark_intensity(
203+
dec_features,
204+
enc_features,
205+
waveform_stds,
206+
occupancy,
207+
mean_rate,
208+
log_position_distance,
209+
use_gemm=True,
210+
enc_tile_size=n_enc_spikes,
211+
)
212+
assert np.allclose(result_baseline, result_tile_full, rtol=1e-5, atol=1e-7)

0 commit comments

Comments
 (0)