|
| 1 | +#!/usr/bin/env python3 |
| 2 | +"""Analyze all downloaded datasets and produce a comprehensive report. |
| 3 | +
|
| 4 | +Processes: |
| 5 | + - DANDI 001611: rat cortical HD-MEA (2700 NWB files, 12 subjects) |
| 6 | + - Sharf 2022: human brain organoid HD-MEA (33 HDF5 files) |
| 7 | +
|
| 8 | +Outputs per-recording statistics + aggregate summaries + comparison to BL-1. |
| 9 | +""" |
| 10 | + |
| 11 | +import json |
| 12 | +import math |
| 13 | +import sys |
| 14 | +import time |
| 15 | +from pathlib import Path |
| 16 | + |
| 17 | +import numpy as np |
| 18 | + |
| 19 | +from bl1.validation.loaders import ( |
| 20 | + compute_recording_statistics, |
| 21 | + load_maxwell_h5, |
| 22 | + load_nwb_spike_trains, |
| 23 | + spike_trains_to_raster, |
| 24 | +) |
| 25 | +from bl1.validation.datasets import DATASETS |
| 26 | + |
| 27 | +DATA_DIR = Path("/data/datasets/bl1") |
| 28 | +RESULTS_DIR = Path("results/dataset_analysis") |
| 29 | +RESULTS_DIR.mkdir(parents=True, exist_ok=True) |
| 30 | + |
| 31 | + |
| 32 | +def process_nwb(filepath, max_duration_s=60.0): |
| 33 | + """Load NWB, handle sample-index conversion, compute stats.""" |
| 34 | + data = load_nwb_spike_trains(str(filepath)) |
| 35 | + if data["n_units"] < 5: |
| 36 | + return None |
| 37 | + |
| 38 | + # Auto-detect sample indices |
| 39 | + if data["duration_s"] > 86400: |
| 40 | + sr = 20000.0 |
| 41 | + data["spike_times"] = [st / sr for st in data["spike_times"]] |
| 42 | + data["duration_s"] /= sr |
| 43 | + |
| 44 | + dur = min(data["duration_s"], max_duration_s) |
| 45 | + if dur < 5: |
| 46 | + return None |
| 47 | + |
| 48 | + trimmed = { |
| 49 | + "spike_times": [st[st <= dur] for st in data["spike_times"]], |
| 50 | + "duration_s": dur, |
| 51 | + "n_units": data["n_units"], |
| 52 | + } |
| 53 | + stats = compute_recording_statistics(trimmed, dt_ms=0.5, burst_threshold_std=1.5) |
| 54 | + stats["n_units"] = data["n_units"] |
| 55 | + stats["duration_s"] = dur |
| 56 | + stats["full_duration_s"] = data["duration_s"] |
| 57 | + return stats |
| 58 | + |
| 59 | + |
| 60 | +def process_maxwell(filepath, max_duration_s=60.0): |
| 61 | + """Load Maxwell HDF5, compute stats.""" |
| 62 | + data = load_maxwell_h5(str(filepath)) |
| 63 | + if data["n_units"] < 5: |
| 64 | + return None |
| 65 | + |
| 66 | + # Find actual start of spiking activity (some files have late-starting data) |
| 67 | + all_times = [st for st in data["spike_times"] if len(st) > 0] |
| 68 | + if not all_times: |
| 69 | + return None |
| 70 | + t_min = float(min(st.min() for st in all_times)) |
| 71 | + t_max = float(max(st.max() for st in all_times)) |
| 72 | + actual_dur = t_max - t_min |
| 73 | + if actual_dur < 5: |
| 74 | + return None |
| 75 | + |
| 76 | + # Trim window relative to start of activity |
| 77 | + window_end = t_min + min(actual_dur, max_duration_s) |
| 78 | + trimmed_times = [] |
| 79 | + for st in data["spike_times"]: |
| 80 | + mask = (st >= t_min) & (st <= window_end) |
| 81 | + trimmed_times.append(st[mask] - t_min) # shift to t=0 |
| 82 | + |
| 83 | + use_dur = min(actual_dur, max_duration_s) |
| 84 | + trimmed = { |
| 85 | + "spike_times": trimmed_times, |
| 86 | + "duration_s": use_dur, |
| 87 | + "n_units": data["n_units"], |
| 88 | + } |
| 89 | + stats = compute_recording_statistics(trimmed, dt_ms=0.5, burst_threshold_std=1.5) |
| 90 | + stats["n_units"] = data["n_units"] |
| 91 | + stats["duration_s"] = use_dur |
| 92 | + stats["actual_duration_s"] = actual_dur |
| 93 | + stats["sampling_rate"] = data.get("sampling_rate", 20000.0) |
| 94 | + return stats |
| 95 | + |
| 96 | + |
| 97 | +def summarize(records, label): |
| 98 | + """Print aggregate stats for a collection of recordings.""" |
| 99 | + if not records: |
| 100 | + print(f" No valid recordings for {label}") |
| 101 | + return {} |
| 102 | + |
| 103 | + fr = [r["mean_firing_rate_hz"] for r in records] |
| 104 | + br = [r["burst_rate_per_min"] for r in records] |
| 105 | + dur = [r.get("burst_duration_mean_ms", float("nan")) for r in records] |
| 106 | + dur = [d for d in dur if not math.isnan(d)] |
| 107 | + units = [r["n_units"] for r in records] |
| 108 | + |
| 109 | + summary = { |
| 110 | + "n_recordings": len(records), |
| 111 | + "n_units_range": [int(min(units)), int(max(units))], |
| 112 | + "firing_rate_hz": {"mean": np.mean(fr), "std": np.std(fr), |
| 113 | + "min": np.min(fr), "max": np.max(fr)}, |
| 114 | + "burst_rate_per_min": {"mean": np.mean(br), "std": np.std(br), |
| 115 | + "min": np.min(br), "max": np.max(br)}, |
| 116 | + } |
| 117 | + if dur: |
| 118 | + summary["burst_duration_ms"] = {"mean": np.mean(dur), "std": np.std(dur), |
| 119 | + "min": np.min(dur), "max": np.max(dur)} |
| 120 | + |
| 121 | + print(f"\n {label}: {len(records)} recordings") |
| 122 | + print(f" Units: {min(units)} - {max(units)}") |
| 123 | + print(f" FR (Hz): {np.mean(fr):.2f} +/- {np.std(fr):.2f} [{np.min(fr):.2f} - {np.max(fr):.2f}]") |
| 124 | + print(f" Burst/min: {np.mean(br):.1f} +/- {np.std(br):.1f} [{np.min(br):.1f} - {np.max(br):.1f}]") |
| 125 | + if dur: |
| 126 | + print(f" Burst dur: {np.mean(dur):.0f} +/- {np.std(dur):.0f} ms [{np.min(dur):.0f} - {np.max(dur):.0f}]") |
| 127 | + return summary |
| 128 | + |
| 129 | + |
| 130 | +def main(): |
| 131 | + t0 = time.time() |
| 132 | + print("=" * 78) |
| 133 | + print(" BL-1 Multi-Dataset Analysis") |
| 134 | + print("=" * 78) |
| 135 | + |
| 136 | + all_summaries = {} |
| 137 | + |
| 138 | + # ----------------------------------------------------------------------- |
| 139 | + # 1. DANDI 001611 — sample across subjects (10 files per subject) |
| 140 | + # ----------------------------------------------------------------------- |
| 141 | + print("\n[1] DANDI 001611: Rat cortical HD-MEA") |
| 142 | + dandi_dir = DATA_DIR / "dandi_001611_rat_cortical" / "001611" |
| 143 | + dandi_records = [] |
| 144 | + if dandi_dir.exists(): |
| 145 | + subjects = sorted([d for d in dandi_dir.iterdir() if d.is_dir() and d.name.startswith("sub-")]) |
| 146 | + print(f" Subjects: {len(subjects)}") |
| 147 | + for subj in subjects: |
| 148 | + nwb_files = sorted(subj.glob("*.nwb")) |
| 149 | + # Sample up to 5 files per subject for speed |
| 150 | + sample = nwb_files[:5] |
| 151 | + for f in sample: |
| 152 | + try: |
| 153 | + stats = process_nwb(f) |
| 154 | + if stats: |
| 155 | + stats["subject"] = subj.name |
| 156 | + stats["filename"] = f.name |
| 157 | + dandi_records.append(stats) |
| 158 | + sys.stdout.write(".") |
| 159 | + sys.stdout.flush() |
| 160 | + except Exception as e: |
| 161 | + sys.stdout.write("x") |
| 162 | + sys.stdout.flush() |
| 163 | + print() |
| 164 | + all_summaries["dandi_001611"] = summarize(dandi_records, "DANDI 001611 (rat cortical)") |
| 165 | + else: |
| 166 | + print(" Not found") |
| 167 | + |
| 168 | + # ----------------------------------------------------------------------- |
| 169 | + # 2. Sharf 2022 — ALL organoid files (33 total) |
| 170 | + # ----------------------------------------------------------------------- |
| 171 | + print("\n[2] Sharf 2022: Human brain organoid HD-MEA") |
| 172 | + sharf_dir = DATA_DIR / "zenodo_sharf_2022" |
| 173 | + sharf_records = [] |
| 174 | + sharf_dev = [] |
| 175 | + sharf_drug = [] |
| 176 | + sharf_baseline = [] |
| 177 | + |
| 178 | + if sharf_dir.exists(): |
| 179 | + h5_files = sorted(sharf_dir.glob("*.h5")) |
| 180 | + print(f" Files: {len(h5_files)}") |
| 181 | + for f in h5_files: |
| 182 | + try: |
| 183 | + stats = process_maxwell(f) |
| 184 | + if stats: |
| 185 | + stats["filename"] = f.name |
| 186 | + sharf_records.append(stats) |
| 187 | + # Categorize |
| 188 | + if f.name.startswith("Development_"): |
| 189 | + sharf_dev.append(stats) |
| 190 | + elif f.name.startswith("Drug_"): |
| 191 | + sharf_drug.append(stats) |
| 192 | + else: |
| 193 | + sharf_baseline.append(stats) |
| 194 | + sys.stdout.write(".") |
| 195 | + sys.stdout.flush() |
| 196 | + except Exception as e: |
| 197 | + sys.stdout.write("x") |
| 198 | + sys.stdout.flush() |
| 199 | + print() |
| 200 | + all_summaries["sharf_2022_all"] = summarize(sharf_records, "Sharf 2022 (all)") |
| 201 | + if sharf_baseline: |
| 202 | + all_summaries["sharf_2022_baseline"] = summarize(sharf_baseline, "Sharf 2022 (7-month baseline)") |
| 203 | + if sharf_dev: |
| 204 | + all_summaries["sharf_2022_development"] = summarize(sharf_dev, "Sharf 2022 (development series)") |
| 205 | + if sharf_drug: |
| 206 | + all_summaries["sharf_2022_drug"] = summarize(sharf_drug, "Sharf 2022 (drug dose-response)") |
| 207 | + |
| 208 | + # ----------------------------------------------------------------------- |
| 209 | + # 3. Print comparison table |
| 210 | + # ----------------------------------------------------------------------- |
| 211 | + print("\n" + "=" * 78) |
| 212 | + print(" Cross-Dataset Comparison") |
| 213 | + print("=" * 78) |
| 214 | + print(f"\n {'Dataset':<35s} {'N':>4s} {'FR (Hz)':>12s} {'Burst/min':>12s} {'Units':>10s}") |
| 215 | + print(f" {'-'*35} {'-'*4} {'-'*12} {'-'*12} {'-'*10}") |
| 216 | + |
| 217 | + for name, s in all_summaries.items(): |
| 218 | + if not s: |
| 219 | + continue |
| 220 | + fr = s["firing_rate_hz"] |
| 221 | + br = s["burst_rate_per_min"] |
| 222 | + u = s["n_units_range"] |
| 223 | + print(f" {name:<35s} {s['n_recordings']:4d} " |
| 224 | + f"{fr['mean']:5.1f}+/-{fr['std']:4.1f} " |
| 225 | + f"{br['mean']:5.1f}+/-{br['std']:4.1f} " |
| 226 | + f"{u[0]:4d}-{u[1]:4d}") |
| 227 | + |
| 228 | + # Wagenaar reference |
| 229 | + w = DATASETS["wagenaar_2006"] |
| 230 | + print(f" {'Wagenaar 2006 (reference)':<35s} {'59':>4s} " |
| 231 | + f"{'0.1-5.0':>12s} {'0.2-20':>12s} {'60ch':>10s}") |
| 232 | + |
| 233 | + # ----------------------------------------------------------------------- |
| 234 | + # 4. Save detailed results |
| 235 | + # ----------------------------------------------------------------------- |
| 236 | + # Convert numpy to native Python for JSON |
| 237 | + def to_native(obj): |
| 238 | + if isinstance(obj, (np.floating, np.integer)): |
| 239 | + return obj.item() |
| 240 | + if isinstance(obj, np.ndarray): |
| 241 | + return obj.tolist() |
| 242 | + if isinstance(obj, dict): |
| 243 | + return {k: to_native(v) for k, v in obj.items()} |
| 244 | + if isinstance(obj, list): |
| 245 | + return [to_native(v) for v in obj] |
| 246 | + return obj |
| 247 | + |
| 248 | + out = { |
| 249 | + "summaries": to_native(all_summaries), |
| 250 | + "dandi_records": to_native(dandi_records), |
| 251 | + "sharf_records": to_native(sharf_records), |
| 252 | + } |
| 253 | + out_path = RESULTS_DIR / "dataset_analysis.json" |
| 254 | + with open(out_path, "w") as f: |
| 255 | + json.dump(out, f, indent=2, default=str) |
| 256 | + |
| 257 | + elapsed = time.time() - t0 |
| 258 | + print(f"\n Total time: {elapsed:.0f}s") |
| 259 | + print(f" Results saved to: {out_path}") |
| 260 | + print("=" * 78) |
| 261 | + |
| 262 | + |
| 263 | +if __name__ == "__main__": |
| 264 | + main() |
0 commit comments