Skip to content

Commit 04d0a4d

Browse files
committed
Refactor imports and clean up print statements
Standardized and cleaned up print statements across scripts for consistency. Refactored and reordered imports in several scripts and test files, removing unused imports and fixing import order. Minor code cleanups include removing unused variables, correcting loop variable names, and fixing assignment statements in tests.
1 parent d9cd2aa commit 04d0a4d

16 files changed

+66
-62
lines changed

scripts/compare_kde_implementations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def main():
5656
n_runs = 10
5757
dimensions = [2, 4, 6, 8, 10]
5858

59-
print(f"Configuration:")
59+
print("Configuration:")
6060
print(f" Evaluation points: {n_eval}")
6161
print(f" Training samples: {n_samples}")
6262
print(f" Runs per test: {n_runs}")

scripts/compare_memory_usage.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def analyze_array_allocations(n_eval, n_samples, n_dims):
128128
print(f" Per-dimension output: {n_samples} × {n_eval} × 8 bytes = {output_size/1024:.2f} KB")
129129
vmap_intermediates = n_dims * n_samples * n_eval * 8 # All dimensions at once
130130
print(f" All intermediate arrays: {n_dims} × {output_size/1024:.2f} KB = {vmap_intermediates/1024:.2f} KB")
131-
print(f" XLA fusion may reduce this through operation fusion")
131+
print(" XLA fusion may reduce this through operation fusion")
132132
optimized_peak = vmap_intermediates + output_size
133133
print(f" **Estimated peak (worst case): {optimized_peak/1024:.2f} KB**")
134134
print(f" **Estimated peak (with fusion): {(vmap_intermediates*0.5 + output_size)/1024:.2f} KB**")
@@ -141,7 +141,7 @@ def analyze_array_allocations(n_eval, n_samples, n_dims):
141141

142142
if ratio > 1.5:
143143
print(f"⚠️ Optimized version may use {ratio:.1f}x more memory in worst case")
144-
print(f" But XLA fusion likely reduces this significantly")
144+
print(" But XLA fusion likely reduces this significantly")
145145
elif ratio > 0.7:
146146
print(f"✓ Similar memory footprint ({ratio:.2f}x)")
147147
else:
@@ -222,7 +222,7 @@ def main():
222222
elif mem_ratio < 0.8:
223223
print(f"✓ Optimized uses {1/mem_ratio:.1f}x less memory")
224224
else:
225-
print(f"✓ Similar memory usage")
225+
print("✓ Similar memory usage")
226226

227227
# Estimate device memory
228228
print("\nEstimated JAX Device Memory:")
@@ -293,7 +293,7 @@ def main():
293293
print(" for all dimensions simultaneously (parallel execution),")
294294
print(" while the original processes one dimension at a time.")
295295
print()
296-
print(" **Trade-off**: Speed (10x faster) vs Memory (~{:.1f}x more)".format(avg_ratio))
296+
print(f" **Trade-off**: Speed (10x faster) vs Memory (~{avg_ratio:.1f}x more)")
297297
elif avg_ratio > 1.2:
298298
print(f"⚠️ Optimized version uses slightly more memory ({avg_ratio:.2f}x)")
299299
else:
@@ -310,7 +310,7 @@ def main():
310310
print("**Memory-Speed Trade-off**")
311311
print()
312312
print("The optimized version (clusterless_kde_log.py):")
313-
print(f" ✅ Speed: 10.8x faster")
313+
print(" ✅ Speed: 10.8x faster")
314314
print(f" ⚠️ Memory: ~{avg_ratio:.1f}x more usage")
315315
print()
316316
print("**Use optimized version when:**")

scripts/investigate_extreme_values.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def analyze_extreme_features():
6969

7070
feature_distances = np.array(feature_distances)
7171

72-
print(f"Feature space distances (standardized by std):")
72+
print("Feature space distances (standardized by std):")
7373
print(f" Min: {feature_distances.min():.2f} std")
7474
print(f" Max: {feature_distances.max():.2f} std")
7575
print(f" Mean: {feature_distances.mean():.2f} std")
@@ -91,7 +91,7 @@ def analyze_extreme_features():
9191

9292
# Compute position distances
9393
position_distance = kde_distance(interior_bins, enc_positions, position_std)
94-
log_position_distance = log_kde_distance(interior_bins, enc_positions, position_std)
94+
log_kde_distance(interior_bins, enc_positions, position_std)
9595

9696
print("-" * 80)
9797
print("ORIGINAL IMPLEMENTATION (clusterless_kde.py)")
@@ -118,14 +118,14 @@ def analyze_extreme_features():
118118
n_inf = np.sum(np.isinf(ll_original))
119119
n_nan = np.sum(np.isnan(ll_original))
120120

121-
print(f"Value distribution:")
121+
print("Value distribution:")
122122
print(f" Finite values: {n_finite}/{ll_original.size} ({100*n_finite/ll_original.size:.1f}%)")
123123
print(f" -Inf values: {n_inf}/{ll_original.size} ({100*n_inf/ll_original.size:.1f}%)")
124124
print(f" NaN values: {n_nan}/{ll_original.size} ({100*n_nan/ll_original.size:.1f}%)")
125125
print()
126126

127127
if n_finite > 0:
128-
print(f"Finite value statistics:")
128+
print("Finite value statistics:")
129129
print(f" Min: {ll_original[finite_mask].min():.4f}")
130130
print(f" Max: {ll_original[finite_mask].max():.4f}")
131131
print(f" Mean: {ll_original[finite_mask].mean():.4f}")
@@ -167,7 +167,7 @@ def analyze_extreme_features():
167167
n_inf_log = np.sum(np.isinf(ll_log_no_gemm))
168168
n_nan_log = np.sum(np.isnan(ll_log_no_gemm))
169169

170-
print(f"Value distribution:")
170+
print("Value distribution:")
171171
print(f" Finite values: {n_finite_log}/{ll_log_no_gemm.size} ({100*n_finite_log/ll_log_no_gemm.size:.1f}%)")
172172
print(f" -Inf values: {n_inf_log}/{ll_log_no_gemm.size} ({100*n_inf_log/ll_log_no_gemm.size:.1f}%)")
173173
print(f" NaN values: {n_nan_log}/{ll_log_no_gemm.size} ({100*n_nan_log/ll_log_no_gemm.size:.1f}%)")
@@ -211,14 +211,14 @@ def analyze_extreme_features():
211211
n_inf_gemm = np.sum(np.isinf(ll_log_gemm))
212212
n_nan_gemm = np.sum(np.isnan(ll_log_gemm))
213213

214-
print(f"Value distribution:")
214+
print("Value distribution:")
215215
print(f" Finite values: {n_finite_gemm}/{ll_log_gemm.size} ({100*n_finite_gemm/ll_log_gemm.size:.1f}%)")
216216
print(f" -Inf values: {n_inf_gemm}/{ll_log_gemm.size} ({100*n_inf_gemm/ll_log_gemm.size:.1f}%)")
217217
print(f" NaN values: {n_nan_gemm}/{ll_log_gemm.size} ({100*n_nan_gemm/ll_log_gemm.size:.1f}%)")
218218
print()
219219

220220
if n_finite_gemm > 0:
221-
print(f"Finite value statistics:")
221+
print("Finite value statistics:")
222222
print(f" Min: {ll_log_gemm[finite_mask_gemm].min():.4f}")
223223
print(f" Max: {ll_log_gemm[finite_mask_gemm].max():.4f}")
224224
print(f" Mean: {ll_log_gemm[finite_mask_gemm].mean():.4f}")
@@ -254,7 +254,7 @@ def analyze_extreme_features():
254254

255255
if n_finite_gemm > n_finite:
256256
improvement = n_finite_gemm - n_finite
257-
print(f"✅ GEMM optimization improves numerical stability:")
257+
print("✅ GEMM optimization improves numerical stability:")
258258
print(f" {improvement} more finite values ({100*improvement/ll_original.size:.1f}% of total)")
259259
print()
260260
print(" The GEMM approach computes in log-space throughout,")
@@ -289,7 +289,7 @@ def analyze_extreme_features():
289289
if n_finite_gemm > n_finite * 1.1: # At least 10% improvement
290290
print("**Use GEMM optimization for extreme features**")
291291
print()
292-
print(f"The GEMM approach (use_gemm=True) significantly improves")
292+
print("The GEMM approach (use_gemm=True) significantly improves")
293293
print(f"numerical stability, preserving {100*n_finite_gemm/ll_log_gemm.size:.1f}% finite values")
294294
print(f"vs {100*n_finite/ll_original.size:.1f}% for the original.")
295295
print()
@@ -300,7 +300,7 @@ def analyze_extreme_features():
300300
else:
301301
print("**GEMM optimization does not significantly help with extreme features**")
302302
print()
303-
print(f"Both approaches produce similar amounts of underflow")
303+
print("Both approaches produce similar amounts of underflow")
304304
print(f"({100*n_finite/ll_original.size:.1f}% vs {100*n_finite_gemm/ll_log_gemm.size:.1f}% finite).")
305305
print()
306306
print("With such extreme feature distances, underflow is unavoidable.")

scripts/profile_feature_dimensions.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,22 @@
66

77
import sys
88
import time
9+
910
import numpy as np
10-
import jax.numpy as jnp
1111

1212
sys.path.insert(0, "src")
1313

14+
from non_local_detector.environment import Environment
1415
from non_local_detector.likelihoods.clusterless_kde import (
1516
fit_clusterless_kde_encoding_model,
1617
predict_clusterless_kde_log_likelihood,
1718
)
1819
from non_local_detector.likelihoods.clusterless_kde_log import (
1920
fit_clusterless_kde_encoding_model as fit_log,
21+
)
22+
from non_local_detector.likelihoods.clusterless_kde_log import (
2023
predict_clusterless_kde_log_likelihood as predict_log,
2124
)
22-
from non_local_detector.environment import Environment
2325

2426

2527
def create_test_data(n_features, n_encoding=200, n_decoding=100, n_positions=500):
@@ -189,7 +191,7 @@ def profile_dimension(n_features):
189191
speedup_gemm = mean_ref / mean_gemm
190192
speedup_no_gemm = mean_ref / mean_no_gemm
191193

192-
print(f"\nSpeedup vs reference:")
194+
print("\nSpeedup vs reference:")
193195
print(f" GEMM (vmap): {speedup_gemm:.2f}x")
194196
print(f" No GEMM (log): {speedup_no_gemm:.2f}x")
195197

scripts/profile_feature_dimensions_simple.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66

77
import sys
88
import time
9-
import numpy as np
9+
1010
import jax.numpy as jnp
11+
import numpy as np
1112

1213
sys.path.insert(0, "src")
1314

@@ -103,7 +104,7 @@ def profile_dimension(n_features, n_enc=200, n_dec=100, n_pos=500):
103104
speedup_no_gemm = mean_ref / mean_no_gemm
104105
gemm_vs_no_gemm = mean_no_gemm / mean_gemm
105106

106-
print(f"\nSpeedup vs reference:")
107+
print("\nSpeedup vs reference:")
107108
print(f" GEMM (vmap): {speedup_gemm:.2f}x")
108109
print(f" No GEMM: {speedup_no_gemm:.2f}x")
109110
print(f"\nGEMM vs No-GEMM: {gemm_vs_no_gemm:.2f}x")

scripts/profile_log_kde_optimization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def main():
7979
n_runs = 10
8080
dimensions = [2, 4, 6, 8, 10]
8181

82-
print(f"Configuration:")
82+
print("Configuration:")
8383
print(f" Evaluation points: {n_eval}")
8484
print(f" Training samples: {n_samples}")
8585
print(f" Runs per test: {n_runs}")

scripts/profile_optimized_kde.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,16 @@
22

33
import sys
44
import time
5-
import numpy as np
5+
66
import jax.numpy as jnp
7+
import numpy as np
78

89
sys.path.insert(0, "src")
910

1011
from non_local_detector.likelihoods.clusterless_kde import (
1112
estimate_log_joint_mark_intensity,
12-
estimate_log_joint_mark_intensity_vectorized,
1313
estimate_log_joint_mark_intensity_logspace,
14+
estimate_log_joint_mark_intensity_vectorized,
1415
)
1516

1617

@@ -167,13 +168,13 @@ def main():
167168
print(f" Average speedup: {avg_speedup:.2f}x")
168169

169170
if avg_speedup >= 2.0:
170-
print(f" ✓ Excellent! Achieved target of 2-4x speedup")
171+
print(" ✓ Excellent! Achieved target of 2-4x speedup")
171172
elif avg_speedup >= 1.5:
172-
print(f" ✓ Good! Significant performance improvement")
173+
print(" ✓ Good! Significant performance improvement")
173174
elif avg_speedup >= 1.2:
174-
print(f" ~ Moderate improvement")
175+
print(" ~ Moderate improvement")
175176
else:
176-
print(f" ✗ Minimal improvement")
177+
print(" ✗ Minimal improvement")
177178

178179
print("\nLog-space + Vectorized + JIT Optimization:")
179180
log_speedups = [r[2]["logspace"][2] for r in all_results]
@@ -184,13 +185,13 @@ def main():
184185
print(f" Average speedup: {avg_speedup:.2f}x")
185186

186187
if avg_speedup >= 2.0:
187-
print(f" ✓ Excellent! Achieved target of 2-4x speedup")
188+
print(" ✓ Excellent! Achieved target of 2-4x speedup")
188189
elif avg_speedup >= 1.5:
189-
print(f" ✓ Good! Significant performance improvement")
190+
print(" ✓ Good! Significant performance improvement")
190191
elif avg_speedup >= 1.2:
191-
print(f" ~ Moderate improvement")
192+
print(" ~ Moderate improvement")
192193
else:
193-
print(f" ✗ Minimal improvement")
194+
print(" ✗ Minimal improvement")
194195

195196
# Comparison: Vectorized vs Log-space
196197
print("\nLog-space vs Vectorized:")
@@ -212,22 +213,22 @@ def main():
212213
print("\nBest implementation:")
213214
if log_avg > vec_avg * 1.1:
214215
print(f" → Log-space + Vectorized + JIT ({log_avg:.2f}x average speedup)")
215-
print(f" Use: estimate_log_joint_mark_intensity_logspace()")
216+
print(" Use: estimate_log_joint_mark_intensity_logspace()")
216217
elif vec_avg > 1.5:
217218
print(f" → Vectorized + JIT ({vec_avg:.2f}x average speedup)")
218-
print(f" Use: estimate_log_joint_mark_intensity_vectorized()")
219+
print(" Use: estimate_log_joint_mark_intensity_vectorized()")
219220
else:
220-
print(f" → Original implementation (optimizations not beneficial)")
221-
print(f" Use: estimate_log_joint_mark_intensity()")
221+
print(" → Original implementation (optimizations not beneficial)")
222+
print(" Use: estimate_log_joint_mark_intensity()")
222223

223224
print("\nFor production use:")
224225
if max(vec_avg, log_avg) >= 2.0:
225-
print(f" ✓ Optimization successful - recommend deploying optimized version")
226-
print(f" ✓ Numerical equivalence verified (max diff < 1e-6)")
226+
print(" ✓ Optimization successful - recommend deploying optimized version")
227+
print(" ✓ Numerical equivalence verified (max diff < 1e-6)")
227228
print(f" ✓ Average speedup: {max(vec_avg, log_avg):.2f}x")
228229
else:
229-
print(f" ~ Optimization provides moderate benefit")
230-
print(f" ~ Consider for performance-critical applications only")
230+
print(" ~ Optimization provides moderate benefit")
231+
print(" ~ Consider for performance-critical applications only")
231232

232233
print("\n" + "="*70)
233234

scripts/test_optimized_kde.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
"""Test numerical equivalence of optimized KDE implementations."""
22

33
import sys
4-
import numpy as np
4+
55
import jax.numpy as jnp
6+
import numpy as np
67

78
sys.path.insert(0, "src")
89

910
from non_local_detector.likelihoods.clusterless_kde import (
11+
estimate_log_joint_mark_intensity,
12+
estimate_log_joint_mark_intensity_logspace,
13+
estimate_log_joint_mark_intensity_vectorized,
1014
kde_distance,
1115
kde_distance_vectorized,
1216
log_kde_distance,
13-
estimate_log_joint_mark_intensity,
14-
estimate_log_joint_mark_intensity_vectorized,
15-
estimate_log_joint_mark_intensity_logspace,
1617
)
1718

1819

@@ -131,12 +132,12 @@ def test_estimate_functions_equivalence():
131132
# Check with more appropriate tolerance for numerical differences
132133
is_close = jnp.allclose(result_original, result_vectorized, rtol=1e-5, atol=1e-6)
133134
if not is_close:
134-
print(f" WARNING: Differences exceed tolerance")
135+
print(" WARNING: Differences exceed tolerance")
135136
print(f" Max absolute diff: {max_diff_vec}")
136137
print(f" Max relative diff: {rel_diff_vec}")
137138
# Check if it's just due to float precision
138139
if max_diff_vec < 1e-5 and rel_diff_vec < 1e-4:
139-
print(f" Differences are within acceptable float32 precision, continuing...")
140+
print(" Differences are within acceptable float32 precision, continuing...")
140141
else:
141142
raise AssertionError(f"Vectorized version not equivalent for {n_features}D")
142143

@@ -153,12 +154,12 @@ def test_estimate_functions_equivalence():
153154
# Check with more appropriate tolerance
154155
is_close = jnp.allclose(result_original, result_logspace, rtol=1e-5, atol=1e-6)
155156
if not is_close:
156-
print(f" WARNING: Differences exceed tolerance")
157+
print(" WARNING: Differences exceed tolerance")
157158
print(f" Max absolute diff: {max_diff_log}")
158159
print(f" Max relative diff: {rel_diff_log}")
159160
# Check if it's just due to float precision
160161
if max_diff_log < 1e-5 and rel_diff_log < 1e-4:
161-
print(f" Differences are within acceptable float32 precision, continuing...")
162+
print(" Differences are within acceptable float32 precision, continuing...")
162163
else:
163164
raise AssertionError(f"Log-space version not equivalent for {n_features}D")
164165

src/non_local_detector/likelihoods/clusterless_kde_log.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
EPS,
1010
LOG_EPS,
1111
KDEModel,
12-
block_kde,
1312
block_log_kde,
1413
gaussian_pdf,
1514
get_position_at_time,

src/non_local_detector/tests/integration/test_clusterless_kde_parity.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def test_estimate_intensity_moderate_features():
120120
np.random.seed(42)
121121
dec_features = jnp.array(np.random.randn(n_dec_spikes, n_features) * 10 + 50)
122122
enc_features = jnp.array(np.random.randn(n_enc_spikes, n_features) * 10 + 50)
123-
enc_weights = jnp.ones(n_enc_spikes)
123+
jnp.ones(n_enc_spikes)
124124
waveform_stds = jnp.array([5.0] * n_features)
125125
occupancy = jnp.ones(n_pos_bins) * 0.1
126126
mean_rate = 5.0
@@ -185,7 +185,7 @@ def test_estimate_intensity_extreme_features():
185185
np.random.seed(42)
186186
dec_features = jnp.array(np.random.randn(n_dec_spikes, n_features) * 50 + 100)
187187
enc_features = jnp.array(np.random.randn(n_enc_spikes, n_features) * 50 + 200)
188-
enc_weights = jnp.ones(n_enc_spikes)
188+
jnp.ones(n_enc_spikes)
189189
waveform_stds = jnp.array([10.0] * n_features)
190190
occupancy = jnp.ones(n_pos_bins) * 0.1
191191
mean_rate = 2.0

0 commit comments

Comments (0)