Skip to content

Commit fff0aa4

Browse files
m9h and claude committed
Add real-data training pipeline, GPU kernels, educational materials
Training pipeline:
- Rewrite trainer for numerical stability: log-scale FR loss, NaN-safe gradients, per-element weight clamping, differentiable burst proxy
- Add --from-recording flag to train_culture.py for loading NWB/HDF5 recordings and extracting targets automatically
- Fix activity window detection for late-starting recordings (Sharf)
- Fix integrator scan dtype mismatch for JAX 0.9 (bool→float32 spikes)

GPU kernel development:
- Add CSC event-driven v2 (flat gather, no 2D waste) to pallas_ops.py
- Add Pallas GPU kernel for CSC synaptic input with atomic_add
- Add benchmark suite comparing BCOO, CSC v1/v2, and Pallas
- Finding: BCOO/cuSPARSE is fastest at all scales tested (5-15ms at 10K-100K neurons); event-driven approaches have too much JAX dispatch overhead to beat it

Real-data validation:
- Add validate_real_data.py for DANDI + Sharf dataset analysis
- Add analyze_all_datasets.py for multi-dataset batch processing
- Fix Maxwell HDF5 loader for compound spikeTimes datasets (Sharf 2022)
- Successfully processed 60 DANDI + 33 Sharf recordings (86 GB)

Educational materials:
- Add notebooks/02_real_vs_simulated.ipynb (DANDI vs BL-1 comparison)
- Add docs/slides/bl1_overview.tex (10-slide Beamer presentation with TikZ architecture diagrams and pseudocode)

Infrastructure:
- Add scripts/nsg_submit.py for NSG/SDSC Expanse GPU job submission via CIPRES REST API (GPU_PY_EXPANSE tool)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 6cbb28f commit fff0aa4

10 files changed

Lines changed: 2018 additions & 215 deletions

File tree

docs/slides/bl1_overview.tex

Lines changed: 448 additions & 0 deletions
Large diffs are not rendered by default.

notebooks/02_real_vs_simulated.ipynb

Lines changed: 121 additions & 0 deletions
Large diffs are not rendered by default.

scripts/analyze_all_datasets.py

Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
#!/usr/bin/env python3
2+
"""Analyze all downloaded datasets and produce a comprehensive report.
3+
4+
Processes:
5+
- DANDI 001611: rat cortical HD-MEA (2700 NWB files, 12 subjects)
6+
- Sharf 2022: human brain organoid HD-MEA (33 HDF5 files)
7+
8+
Outputs per-recording statistics + aggregate summaries + comparison to BL-1.
9+
"""
10+
11+
import json
12+
import math
13+
import sys
14+
import time
15+
from pathlib import Path
16+
17+
import numpy as np
18+
19+
from bl1.validation.loaders import (
20+
compute_recording_statistics,
21+
load_maxwell_h5,
22+
load_nwb_spike_trains,
23+
spike_trains_to_raster,
24+
)
25+
from bl1.validation.datasets import DATASETS
26+
27+
# Root directory under which main() looks for the downloaded datasets.
DATA_DIR = Path("/data/datasets/bl1")
28+
# Output directory for the JSON report; a relative path, so it is resolved
# against the current working directory.
RESULTS_DIR = Path("results/dataset_analysis")
29+
# NOTE: runs at import time, so importing this module creates the directory.
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
30+
31+
32+
def process_nwb(filepath, max_duration_s=60.0):
    """Compute statistics for the first part of one NWB recording.

    The file is loaded with ``load_nwb_spike_trains``. Recordings with fewer
    than 5 units, or with less than 5 s of usable data, are skipped.

    Args:
        filepath: Path to the NWB file (converted with ``str`` before loading).
        max_duration_s: Upper bound, in seconds, on the analysed window.

    Returns:
        The result of ``compute_recording_statistics`` extended with
        ``n_units``, ``duration_s`` (analysed window) and ``full_duration_s``
        (duration reported by the loader, after any unit conversion), or
        ``None`` when the recording is skipped.
    """
    recording = load_nwb_spike_trains(str(filepath))
    if recording["n_units"] < 5:
        return None

    # Heuristic: a "duration" above 86400 (more than 24 h if it were seconds)
    # is taken to mean the times are sample indices, which are then converted
    # with a fixed 20 kHz rate.
    if recording["duration_s"] > 86400:
        sample_rate = 20000.0
        recording["spike_times"] = [
            train / sample_rate for train in recording["spike_times"]
        ]
        recording["duration_s"] /= sample_rate

    full_duration = recording["duration_s"]
    window = min(full_duration, max_duration_s)
    if window < 5:
        return None

    # Keep only spikes at or before the end of the analysis window.
    clipped_trains = []
    for train in recording["spike_times"]:
        clipped_trains.append(train[train <= window])

    stats = compute_recording_statistics(
        {
            "spike_times": clipped_trains,
            "duration_s": window,
            "n_units": recording["n_units"],
        },
        dt_ms=0.5,
        burst_threshold_std=1.5,
    )
    stats["n_units"] = recording["n_units"]
    stats["duration_s"] = window
    stats["full_duration_s"] = full_duration
    return stats
58+
59+
60+
def process_maxwell(filepath, max_duration_s=60.0):
    """Compute statistics for the first active part of a Maxwell HDF5 file.

    The analysis window starts at the earliest spike in the file rather than
    at t=0, because some files contain late-starting data. Recordings with
    fewer than 5 units, no spikes at all, or less than 5 s between the first
    and last spike are skipped.

    Args:
        filepath: Path to the HDF5 file (converted with ``str`` before loading).
        max_duration_s: Upper bound, in seconds, on the analysed window.

    Returns:
        The result of ``compute_recording_statistics`` extended with
        ``n_units``, ``duration_s`` (analysed window), ``actual_duration_s``
        (first-to-last-spike span) and ``sampling_rate`` (20000.0 when the
        loader does not provide one), or ``None`` when the recording is skipped.
    """
    recording = load_maxwell_h5(str(filepath))
    if recording["n_units"] < 5:
        return None

    # Only non-empty trains can contribute to the first/last spike time.
    active = [train for train in recording["spike_times"] if len(train) > 0]
    if not active:
        return None

    first_spike = float(min(train.min() for train in active))
    last_spike = float(max(train.max() for train in active))
    active_span = last_spike - first_spike
    if active_span < 5:
        return None

    # Window is measured from the first spike; spike times are shifted so the
    # window starts at t=0.
    window_len = min(active_span, max_duration_s)
    window_stop = first_spike + window_len
    shifted_trains = [
        train[(train >= first_spike) & (train <= window_stop)] - first_spike
        for train in recording["spike_times"]
    ]

    stats = compute_recording_statistics(
        {
            "spike_times": shifted_trains,
            "duration_s": window_len,
            "n_units": recording["n_units"],
        },
        dt_ms=0.5,
        burst_threshold_std=1.5,
    )
    stats["n_units"] = recording["n_units"]
    stats["duration_s"] = window_len
    stats["actual_duration_s"] = active_span
    stats["sampling_rate"] = recording.get("sampling_rate", 20000.0)
    return stats
95+
96+
97+
def summarize(records, label):
    """Print and return aggregate statistics over a group of recordings.

    Args:
        records: List of per-recording stat dicts. Each must contain
            ``mean_firing_rate_hz``, ``burst_rate_per_min`` and ``n_units``;
            ``burst_duration_mean_ms`` is optional and NaN values are ignored.
        label: Human-readable group name used in the printed output.

    Returns:
        A dict with ``n_recordings``, ``n_units_range`` and mean/std/min/max
        dicts for ``firing_rate_hz`` and ``burst_rate_per_min`` (plus
        ``burst_duration_ms`` when at least one recording has a non-NaN burst
        duration). An empty dict is returned when ``records`` is empty.
    """
    if not records:
        print(f" No valid recordings for {label}")
        return {}

    def column(key):
        # Required field: a missing key raises KeyError.
        return [rec[key] for rec in records]

    def spread(values):
        return {
            "mean": np.mean(values),
            "std": np.std(values),
            "min": np.min(values),
            "max": np.max(values),
        }

    rates = column("mean_firing_rate_hz")
    bursts = column("burst_rate_per_min")

    # Burst duration is optional; missing and NaN entries are left out.
    burst_durations = []
    for rec in records:
        value = rec.get("burst_duration_mean_ms", float("nan"))
        if math.isnan(value):
            continue
        burst_durations.append(value)

    unit_counts = column("n_units")
    fewest, most = min(unit_counts), max(unit_counts)

    rate_stats = spread(rates)
    burst_stats = spread(bursts)
    summary = {
        "n_recordings": len(records),
        "n_units_range": [int(fewest), int(most)],
        "firing_rate_hz": rate_stats,
        "burst_rate_per_min": burst_stats,
    }
    duration_stats = None
    if burst_durations:
        duration_stats = spread(burst_durations)
        summary["burst_duration_ms"] = duration_stats

    print(f"\n {label}: {len(records)} recordings")
    print(f" Units: {fewest} - {most}")
    print(f" FR (Hz): {rate_stats['mean']:.2f} +/- {rate_stats['std']:.2f} [{rate_stats['min']:.2f} - {rate_stats['max']:.2f}]")
    print(f" Burst/min: {burst_stats['mean']:.1f} +/- {burst_stats['std']:.1f} [{burst_stats['min']:.1f} - {burst_stats['max']:.1f}]")
    if duration_stats is not None:
        print(f" Burst dur: {duration_stats['mean']:.0f} +/- {duration_stats['std']:.0f} ms [{duration_stats['min']:.0f} - {duration_stats['max']:.0f}]")
    return summary
128+
129+
130+
def main():
    """Analyse every available dataset and write a JSON report.

    Steps:
      1. DANDI 001611: the first 5 NWB files (sorted by name) of each
         ``sub-*`` directory are processed with ``process_nwb``.
      2. Sharf 2022: every ``*.h5`` file is processed with
         ``process_maxwell`` and grouped by filename prefix
         (``Development_``, ``Drug_``, everything else = baseline).
      3. A cross-dataset comparison table is printed.
      4. Summaries, per-recording stats and the list of files that failed
         to process are written to ``RESULTS_DIR / "dataset_analysis.json"``.

    Processing is best-effort: a file that raises is marked with ``x`` in the
    progress line, reported on stderr once its section finishes, and recorded
    under ``failures`` in the JSON report instead of aborting the run.
    """
    t0 = time.time()
    print("=" * 78)
    print(" BL-1 Multi-Dataset Analysis")
    print("=" * 78)

    all_summaries = {}
    failures = []  # one {"file", "error"} dict per file that raised

    def progress(mark):
        sys.stdout.write(mark)
        sys.stdout.flush()

    def record_failure(path, exc):
        failures.append({"file": str(path), "error": f"{type(exc).__name__}: {exc}"})
        progress("x")

    def report_failures(start):
        # Print the failures collected since index `start`, after the
        # progress line has been terminated.
        for item in failures[start:]:
            print(f" failed: {item['file']}: {item['error']}", file=sys.stderr)

    # -----------------------------------------------------------------------
    # 1. DANDI 001611 — sample across subjects (up to 5 files per subject)
    # -----------------------------------------------------------------------
    print("\n[1] DANDI 001611: Rat cortical HD-MEA")
    dandi_dir = DATA_DIR / "dandi_001611_rat_cortical" / "001611"
    dandi_records = []
    if dandi_dir.exists():
        subjects = sorted(
            d for d in dandi_dir.iterdir()
            if d.is_dir() and d.name.startswith("sub-")
        )
        print(f" Subjects: {len(subjects)}")
        first_failure = len(failures)
        for subj in subjects:
            # Sample up to 5 files per subject for speed
            for f in sorted(subj.glob("*.nwb"))[:5]:
                try:
                    stats = process_nwb(f)
                except Exception as e:  # best-effort: keep going, but record why
                    record_failure(f, e)
                    continue
                if stats:
                    stats["subject"] = subj.name
                    stats["filename"] = f.name
                    dandi_records.append(stats)
                progress(".")
        print()
        report_failures(first_failure)
        all_summaries["dandi_001611"] = summarize(dandi_records, "DANDI 001611 (rat cortical)")
    else:
        print(" Not found")

    # -----------------------------------------------------------------------
    # 2. Sharf 2022 — ALL organoid files (33 total)
    # -----------------------------------------------------------------------
    print("\n[2] Sharf 2022: Human brain organoid HD-MEA")
    sharf_dir = DATA_DIR / "zenodo_sharf_2022"
    sharf_records = []
    sharf_dev = []
    sharf_drug = []
    sharf_baseline = []

    if sharf_dir.exists():
        h5_files = sorted(sharf_dir.glob("*.h5"))
        print(f" Files: {len(h5_files)}")
        first_failure = len(failures)
        for f in h5_files:
            try:
                stats = process_maxwell(f)
            except Exception as e:  # best-effort: keep going, but record why
                record_failure(f, e)
                continue
            if stats:
                stats["filename"] = f.name
                sharf_records.append(stats)
                # Categorize
                if f.name.startswith("Development_"):
                    sharf_dev.append(stats)
                elif f.name.startswith("Drug_"):
                    sharf_drug.append(stats)
                else:
                    sharf_baseline.append(stats)
            progress(".")
        print()
        report_failures(first_failure)
        all_summaries["sharf_2022_all"] = summarize(sharf_records, "Sharf 2022 (all)")
        if sharf_baseline:
            all_summaries["sharf_2022_baseline"] = summarize(sharf_baseline, "Sharf 2022 (7-month baseline)")
        if sharf_dev:
            all_summaries["sharf_2022_development"] = summarize(sharf_dev, "Sharf 2022 (development series)")
        if sharf_drug:
            all_summaries["sharf_2022_drug"] = summarize(sharf_drug, "Sharf 2022 (drug dose-response)")
    else:
        # Same message as the DANDI section when the data directory is missing.
        print(" Not found")

    # -----------------------------------------------------------------------
    # 3. Print comparison table
    # -----------------------------------------------------------------------
    print("\n" + "=" * 78)
    print(" Cross-Dataset Comparison")
    print("=" * 78)
    print(f"\n {'Dataset':<35s} {'N':>4s} {'FR (Hz)':>12s} {'Burst/min':>12s} {'Units':>10s}")
    print(f" {'-'*35} {'-'*4} {'-'*12} {'-'*12} {'-'*10}")

    for name, s in all_summaries.items():
        if not s:
            continue
        fr = s["firing_rate_hz"]
        br = s["burst_rate_per_min"]
        u = s["n_units_range"]
        print(f" {name:<35s} {s['n_recordings']:4d} "
              f"{fr['mean']:5.1f}+/-{fr['std']:4.1f} "
              f"{br['mean']:5.1f}+/-{br['std']:4.1f} "
              f"{u[0]:4d}-{u[1]:4d}")

    # Wagenaar reference row: the values are hard-coded in the format string,
    # so no DATASETS lookup is needed to print it.
    print(f" {'Wagenaar 2006 (reference)':<35s} {'59':>4s} "
          f"{'0.1-5.0':>12s} {'0.2-20':>12s} {'60ch':>10s}")

    # -----------------------------------------------------------------------
    # 4. Save detailed results
    # -----------------------------------------------------------------------
    # Convert numpy to native Python for JSON
    def to_native(obj):
        if isinstance(obj, (np.floating, np.integer)):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, dict):
            return {k: to_native(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [to_native(v) for v in obj]
        return obj

    out = {
        "summaries": to_native(all_summaries),
        "dandi_records": to_native(dandi_records),
        "sharf_records": to_native(sharf_records),
        "failures": failures,
    }
    out_path = RESULTS_DIR / "dataset_analysis.json"
    with open(out_path, "w", encoding="utf-8") as f:
        # default=str is a last-resort fallback for anything to_native missed.
        json.dump(out, f, indent=2, default=str)

    elapsed = time.time() - t0
    print(f"\n Total time: {elapsed:.0f}s")
    print(f" Results saved to: {out_path}")
    print("=" * 78)
261+
262+
263+
# Script entry point: run the analysis only when executed directly, not on import.
if __name__ == "__main__":
264+
main()

0 commit comments

Comments
 (0)