日期:2026-02-21(2026-03-01 修订) 项目:NeuroHorizon(基于 POYO/POYO+ fork) 代码库:/root/autodl-tmp/NeuroHorizon
POYO(Population-level Your Own, NeurIPS 2023)是一个基于 Transformer 的神经群体解码框架,用于从电生理记录(neural spike data)解码行为输出(如光标速度、手部运动)。
核心特点:
- 支持任意神经和行为模态
- 多 session、多 recording 训练能力
- 按需数据加载(lazy loading HDF5)
- 三个模型变体:POYO、POYOPlus(多任务)、CaPOYO(钙成像)
注:
poyo_mp是一个工厂函数(factory function),返回使用预设 1.3M 参数配置的 POYO 实例,而非独立的模型类。
NeuroHorizon/
├── torch_brain/ # 主包
│ ├── models/ # 模型定义
│ │ ├── poyo.py # POYO 基础模型
│ │ ├── poyo_plus.py # POYO+ 多任务扩展
│ │ └── capoyo.py # CaPOYO 钙成像变体
│ ├── nn/ # 神经网络模块
│ │ ├── rotary_attention.py # 旋转位置编码注意力(RotaryCrossAttention, RotarySelfAttention)
│ │ ├── position_embeddings.py # 时间位置编码(SinusoidalTimeEmbedding, RotaryTimeEmbedding)
│ │ ├── infinite_vocab_embedding.py # 动态词汇嵌入
│ │ ├── multitask_readout.py # 多任务读出层 + prepare_for_multitask_readout()
│ │ ├── embedding.py # 标准嵌入
│ │ ├── feedforward.py # GEGLU FFN 层
│ │ └── loss.py # 损失函数(MSE, CE, Mallow)
│ ├── data/ # 数据加载与采样
│ │ ├── collate.py # 批处理整理(pad8, chain, track_mask8)
│ │ └── sampler.py # 采样器(5种,见§4.2)
│ ├── dataset/ # 数据集 API
│ │ ├── dataset.py # Dataset 类(HDF5加载)
│ │ ├── mixins.py # Dataset mixins
│ │ └── nested.py # NestedSpikingDataset(多数据集组合,命名空间机制)
│ ├── transforms/ # 数据增强
│ │ ├── container.py # Compose, ConditionalChoice, RandomChoice 组合器
│ │ ├── unit_dropout.py # 随机丢弃神经元
│ │ ├── unit_filter.py # 按条件/正则过滤神经元(UnitFilter, UnitFilterById)
│ │ ├── output_sampler.py # 随机采样输出 token(RandomOutputSampler)
│ │ ├── random_crop.py # 随机时间裁剪
│ │ └── random_time_scaling.py # 时间拉伸增强
│ ├── utils/ # 工具函数
│ │ ├── tokenizers.py # token 创建(start/end/latent tokens)
│ │ ├── readout.py # 输出准备(归一化、加权)
│ │ ├── stitcher.py # 预测拼接(重叠窗口合并)
│ │ ├── binning.py # 时间 binning
│ │ └── weights.py # 区间加权
│ ├── registry.py # 模态注册系统
│ └── optim.py # SparseLamb 优化器
├── examples/
│ ├── poyo/ # POYO 训练示例
│ │ ├── train.py # 训练脚本
│ │ ├── configs/ # Hydra 配置文件
│ │ └── datasets/ # 数据集实现
│ └── poyo_plus/ # POYO+ 训练示例
│ ├── train.py # 训练脚本
│ └── configs/ # 配置文件
├── tests/ # 单元测试
├── docs/ # 文档
└── pyproject.toml # 包配置与依赖
关键补充说明:
RotaryTimeEmbedding定义在nn/position_embeddings.py中,被rotary_attention.py引用NestedSpikingDataset通过命名空间将多个 Dataset 组合在一起,recording_id 变为"<dataset_name>/<recording_id>"形式,是 POYO+ 多数据集训练的关键
POYO+ 采用经典的 编码器-处理器-解码器 Transformer 架构:
输入 Spike 序列
↓
[嵌入层] Unit Embedding + Token Type Embedding + Rotary Time Embedding
↓
[编码器] Cross-Attention (spikes → latents) + FFN
↓
[处理层] (6-24 层) Self-Attention + FFN
↓
[解码器] Cross-Attention (latents → outputs) + FFN
↓
[读出层] 任务特定的线性投影
↓
输出预测 (如 2D 光标速度)
| 组件 | 类名 | 文件 | 功能 |
|---|---|---|---|
| Unit Embedding | InfiniteVocabEmbedding |
nn/infinite_vocab_embedding.py |
动态词汇量的 unit 嵌入,支持 lazy 初始化 |
| Session Embedding | InfiniteVocabEmbedding |
同上 | session 级别嵌入 |
| Token Type Embedding | Embedding |
nn/embedding.py |
区分 3 种 token 类型(DEFAULT=0, START_OF_SEQUENCE=1, END_OF_SEQUENCE=2;嵌入表容量 4,index=3 预留) |
| Latent Embedding | Embedding |
同上 | 可学习的 latent tokens |
| 时间编码 | RotaryTimeEmbedding |
nn/position_embeddings.py |
RoFormer 风格旋转位置编码 |
| Cross-Attention | RotaryCrossAttention |
nn/rotary_attention.py |
带 RoPE 的交叉注意力 |
| Self-Attention | RotarySelfAttention |
nn/rotary_attention.py |
带 RoPE 的自注意力 |
| FFN | FeedForward |
nn/feedforward.py |
GEGLU 激活的前馈网络(见下方说明) |
| 多任务读出 | MultitaskReadout |
nn/multitask_readout.py |
按任务分发的线性读出层 |
GEGLU 激活函数:FeedForward 使用 GEGLU(Gated Gaussian Error Linear Unit),而非标准 GELU/ReLU。输入先通过线性层扩展到 dim * mult * 2(默认 mult=4),然后 chunk 为两半——一半经 GELU 激活作为门控,另一半作为值,两者相乘得到 dim * mult 维输出,再经线性层回到 dim。
rotate_value 参数差异:
- 编码器 cross-attention (
enc_atn):rotate_value=True— value 上也应用旋转编码 - 处理层 self-attention (
proc_layers):rotate_value=True - 解码器 cross-attention (
dec_atn):rotate_value=False— 不对 value 应用旋转
| 配置 | 参数量 | dim | depth | latent_step | num_latents_per_step | cross_heads | self_heads | atn_dropout |
|---|---|---|---|---|---|---|---|---|
| POYO-MP | 约1.3M | 64 | 6 | 0.125 | 16 | 2 | 8 | 0.2 |
| POYO-1 | 约11.8M | 128 | 24 | 0.125 | 32 | 4 | 8 | 0.0 |
基于 torch_brain/models/poyo_plus.py 的 forward() 方法:
def forward(self, *, input_unit_index, input_timestamps, input_token_type, input_mask,
latent_index, latent_timestamps,
output_session_index, output_timestamps, output_decoder_index, ...):
# 1. 输入嵌入
inputs = self.unit_emb(input_unit_index) + self.token_type_emb(input_token_type)
input_timestamp_emb = self.rotary_emb(input_timestamps) # RoPE
# 2. Latent tokens
latents = self.latent_emb(latent_index)
latent_timestamp_emb = self.rotary_emb(latent_timestamps)
# 3. 输出查询
output_queries = self.session_emb(output_session_index) + self.task_emb(output_decoder_index)
output_timestamp_emb = self.rotary_emb(output_timestamps)
# 4. 编码:spikes → latents (Perceiver cross-attention)
latents = latents + self.enc_atn(latents, inputs, latent_timestamp_emb, input_timestamp_emb, input_mask)
latents = latents + self.enc_ffn(latents)
# 5. 处理:多层自注意力
for self_attn, self_ff in self.proc_layers:
latents = latents + self.dropout(self_attn(latents, latent_timestamp_emb))
latents = latents + self.dropout(self_ff(latents))
# 6. 解码:latents → outputs (cross-attention)
output_queries = output_queries + self.dec_atn(output_queries, latents, output_timestamp_emb, latent_timestamp_emb)
output_latents = output_queries + self.dec_ffn(output_queries)
# 7. 多任务读出
output = self.readout(output_embs=output_latents, output_readout_index=output_decoder_index, ...)
return outputPOYO 与 POYOPlus 的关键差异:
- 输出查询构建:POYO 版本中无
task_emb,output_queries 仅由session_emb构成 - 输出层:POYO 使用
nn.Linear直接投影;POYOPlus 使用MultitaskReadout按readout_index分发到不同 Linear 层 - 返回值:POYO 返回
Tensor或List[Tensor](取决于unpack_output);POYOPlus 返回Tuple[List[Dict[str, Tensor]]](每个样本的每个任务的预测字典)
Variable-Length Forward:所有 attention 模块均实现了 forward_varlen() 方法,支持将变长序列 chain 后通过 xformers 的 BlockDiagonalMask 高效处理,减少 padding 计算浪费。
CaPOYO 展示了 POYO 框架如何处理非 spike 的连续值输入(calcium imaging traces),其设计模式对 NeuroHorizon 有参考价值:
关键设计差异(相比 POYOPlus):
- input_value_map:
nn.Linear(1, dim // 2)将标量钙信号值映射到半维度空间 - unit_emb 维度减半:
InfiniteVocabEmbedding(dim // 2)(而非dim) - 拼接而非相加:
cat((input_value_map(values), unit_emb(index)), dim=-1)— 值嵌入和 unit 嵌入拼接为完整维度
对 NeuroHorizon 的启示:若 decoder 中需同时输入 bin 信息和 unit 信息(如 concat(bin_repr, unit_emb)),CaPOYO 的拼接模式是可参考的实现方式。
Dataset 类 (torch_brain/dataset/dataset.py):
- 基于 HDF5 文件的 lazy loading
- 每个 HDF5 文件对应一个 recording session
- 通过
temporaldata.Data对象提供结构化访问 - 支持时间域切片:
data.slice(start_time, end_time) - 提供
get_recording_hook方法,可被子类覆盖用于自定义后处理(如SpikingDatasetMixin.get_recording_hook给 unit_id 加前缀)
数据索引 (DatasetIndex):
- 三元组:
(recording_id, start_time, end_time) - 由 Sampler 生成
| 采样器 | 文件位置 | 用途 |
|---|---|---|
RandomFixedWindowSampler |
data/sampler.py |
训练:从训练区间随机采样固定长度窗口,支持时间抖动增强 |
SequentialFixedWindowSampler |
同上 | 确定性顺序滑动窗口 |
TrialSampler |
同上 | 按 trial 区间采样(对 trial-aligned 预测有参考价值) |
DistributedEvaluationSamplerWrapper |
同上 | 通用分布式评估包装器 |
DistributedStitchingFixedWindowSampler |
同上 | 分布式推理 + 拼接:步长 window_length/2 滑动窗口,配合 DecodingStitchEvaluator 使用 |
torch_brain/data/collate.py 提供的核心函数:
pad8(seq): 将序列填充到 8 的倍数(GPU 效率优化)track_mask8(seq): 生成填充位置的布尔 maskchain(sequences): 将变长序列首尾相连track_batch(sequences): 追踪每个元素所属的 batch index
另有 pad, track_mask, pad2d, track_mask2d 等变体。
POYOPlus 的 tokenize() 方法:
- Spike tokens: 每个 spike (unit_id, timestamp) 成为一个 token
- Start/End tokens: 每个 unit 的时间窗口起止标记
- Latent tokens: 等间距可学习 latent tokens(由
create_linspace_latent_tokens生成) - Output queries: session + task 嵌入,在指定时间戳处查询行为预测
输出查询构建:prepare_for_multitask_readout() 负责从 Data 对象中提取各任务时间戳和值、执行 z-score 归一化、根据 eval_interval 生成评估 mask、分配 readout_index。
关键数据流:
Raw HDF5 → Dataset.__getitem__() → data.slice(start, end) → Transforms → model.tokenize() → Batch
Collation 规范:tokenize 返回的字典需使用 pad8(), chain() 等包装函数标记每个字段的 collation 策略,NeuroHorizon 的 tokenize 必须遵循此规范。
基于 PyTorch Lightning + Hydra 配置管理:
TrainWrapper(LightningModule): 封装模型的训练/验证/测试步骤DataModule(LightningDataModule): 管理数据加载和预处理- Hydra YAML 配置文件控制所有超参数
SparseLamb (torch_brain/optim.py):
- LAMB 优化器的变体,只更新梯度非零的参数
- 特别适用于 InfiniteVocabEmbedding(不是每次都激活所有词汇)
- 参数组(
examples/poyo_plus/train.py):sparse=True:unit_emb + session_emb + readout 参数合并为一组- 标准更新:其余所有参数
- Base LR: 3.125e-5(按 batch size 线性缩放)
- Weight Decay: 1e-4
- Scheduler: OneCycleLR, cosine annealing
div_factor=1:初始 lr = max_lr(无从低到高的 warmup 阶段)pct_start=0.5(通过cfg.optim.lr_decay_start配置):前 50% 步数保持高 lr,后 50% cosine 衰减
- DecodingStitchEvaluator(自定义 Lightning Callback):拼接重叠窗口预测、计算任务特定指标(R²Score)、支持加权 loss、在指定区间上评估(如 reach_period)
- MultiTaskDecodingStitchEvaluator:用于 POYOPlus 的多任务评估
torch_brain/registry.py 提供全局模态注册:
@dataclass
class ModalitySpec:
id: int # 唯一数字 ID
dim: int # 输出维度
type: DataType # CONTINUOUS/BINARY/MULTINOMIAL/MULTILABEL
timestamp_key: str # 数据中时间戳的访问路径
value_key: str # 数据中值的访问路径
loss_fn: Callable # 损失函数已注册模态(共 19 个):
| 模态名 | dim | 类型 | Loss |
|---|---|---|---|
| cursor_velocity_2d | 2 | CONTINUOUS | MSE |
| cursor_position_2d | 2 | CONTINUOUS | MSE |
| arm_velocity_2d | 2 | CONTINUOUS | MSE |
| running_speed | 1 | CONTINUOUS | MSE |
| drifting_gratings_orientation | 8 | MULTINOMIAL | CE |
| natural_scenes | 119 | MULTINOMIAL | CE |
| natural_movie_one_frame | 900 | MULTINOMIAL | CE |
| 等... |
核心依赖 (pyproject.toml):
torch~=2.0: 深度学习框架temporaldata>=0.1.3: 时间序列数据结构einops~=0.6.0: 张量操作hydra-core~=1.3.2: 配置管理torchmetrics>=1.6.0: 指标计算pydantic~=2.0: 数据验证
运行时依赖:
lightning: PyTorch Lightning 训练框架torch-optimizer==0.3.0: SparseLamb 优化器wandb: 实验记录brainsets: 神经数据集工具
当前已验证环境 (conda env poyo):
- Python 3.10
- PyTorch 2.10.0+cu128
- Lightning 2.6.1
- brainsets 0.2.1.dev4 (GitHub)