@@ -83,8 +83,25 @@ def test_enc_tile_size_equivalence(enc_tile_size):
8383 assert np .all (np .isfinite (result_with_enc_tiling ))
8484
8585
86- def test_enc_tile_size_with_pos_tile_size ():
87- """Test combined encoding and position tiling."""
86+ @pytest .mark .unit
87+ @pytest .mark .parametrize (
88+ "enc_tile_size,pos_tile_size" ,
89+ [
90+ (30 , 15 ), # Both tiling
91+ (25 , 40 ), # pos_tile_size > n_pos (no position tiling)
92+ (150 , 20 ), # enc_tile_size > n_enc (no encoding tiling)
93+ ],
94+ )
95+ def test_enc_tile_size_with_pos_tile_size (enc_tile_size , pos_tile_size ):
96+ """Test combined encoding and position tiling.
97+
98+ Parameters
99+ ----------
100+ enc_tile_size : int
101+ Encoding chunk size
102+ pos_tile_size : int
103+ Position chunk size
104+ """
88105 n_enc_spikes = 100
89106 n_dec_spikes = 15
90107 n_pos_bins = 40
@@ -125,20 +142,71 @@ def test_enc_tile_size_with_pos_tile_size():
125142 mean_rate ,
126143 log_position_distance ,
127144 use_gemm = True ,
128- pos_tile_size = 15 , # Tile positions
129- enc_tile_size = 30 , # Tile encoding spikes
145+ pos_tile_size = pos_tile_size ,
146+ enc_tile_size = enc_tile_size ,
130147 )
131148
132149 # Should match
150+ max_diff = np .max (np .abs (result_baseline - result_both_tiling ))
133151 assert np .allclose (
134- result_baseline , result_both_tiling , rtol = 1e-5 , atol = 1e-8
135- ), f"Max diff: { np . max ( np . abs ( result_baseline - result_both_tiling )) } "
152+ result_baseline , result_both_tiling , rtol = 1e-5 , atol = 1e-7
153+ ), f"enc_tile_size= { enc_tile_size } , pos_tile_size= { pos_tile_size } : Max diff = { max_diff } "
136154
137- print (f"✓ Combined enc_tile_size=30 + pos_tile_size=15 matches baseline" )
138- print (f" Max diff: { np .max (np .abs (result_baseline - result_both_tiling )):.2e} " )
139155
156+ @pytest .mark .unit
157+ def test_enc_tile_size_edge_cases ():
158+ """Test edge cases for encoding tiling."""
159+ n_enc_spikes = 10
160+ n_dec_spikes = 5
161+ n_pos_bins = 8
162+ n_features = 2
163+
164+ np .random .seed (456 )
165+ dec_features = jnp .array (np .random .randn (n_dec_spikes , n_features ) * 5 )
166+ enc_features = jnp .array (np .random .randn (n_enc_spikes , n_features ) * 5 )
167+ waveform_stds = jnp .array ([2.0 ] * n_features )
168+ occupancy = jnp .ones (n_pos_bins ) * 0.05
169+ mean_rate = 2.0
170+
171+ enc_positions = jnp .array (np .random .uniform (0 , 50 , (n_enc_spikes , 1 )))
172+ interior_bins = jnp .array (np .linspace (0 , 50 , n_pos_bins ))[:, None ]
173+ position_std = jnp .array ([3.0 ])
174+ log_position_distance = log_kde_distance (interior_bins , enc_positions , position_std )
175+
176+ # Baseline
177+ result_baseline = estimate_log_joint_mark_intensity (
178+ dec_features ,
179+ enc_features ,
180+ waveform_stds ,
181+ occupancy ,
182+ mean_rate ,
183+ log_position_distance ,
184+ use_gemm = True ,
185+ enc_tile_size = None ,
186+ )
187+
188+ # Test: enc_tile_size = 1 (smallest possible)
189+ result_tile1 = estimate_log_joint_mark_intensity (
190+ dec_features ,
191+ enc_features ,
192+ waveform_stds ,
193+ occupancy ,
194+ mean_rate ,
195+ log_position_distance ,
196+ use_gemm = True ,
197+ enc_tile_size = 1 ,
198+ )
199+ assert np .allclose (result_baseline , result_tile1 , rtol = 1e-5 , atol = 1e-7 )
140200
141- if __name__ == "__main__" :
142- test_enc_tile_size_equivalence ()
143- test_enc_tile_size_with_pos_tile_size ()
144- print ("\n ✅ All enc_tile_size tests passed!" )
201+ # Test: enc_tile_size = n_enc (no chunking)
202+ result_tile_full = estimate_log_joint_mark_intensity (
203+ dec_features ,
204+ enc_features ,
205+ waveform_stds ,
206+ occupancy ,
207+ mean_rate ,
208+ log_position_distance ,
209+ use_gemm = True ,
210+ enc_tile_size = n_enc_spikes ,
211+ )
212+ assert np .allclose (result_baseline , result_tile_full , rtol = 1e-5 , atol = 1e-7 )
0 commit comments