A novel approach for generating images from brain activity using fMRI data and diffusion models. This project implements a brain adapter that integrates with Stable Diffusion to decode visual imagery from neural signals.
The Brain Adapter leverages the IP-Adapter architecture to inject brain-derived conditioning tokens into the diffusion process, enabling direct fMRI-to-image generation. The system learns to map brain activity patterns to visual features, allowing reconstruction of viewed or imagined images.
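The mapping can be pictured with a small sketch. The class below is a hypothetical, minimal NumPy version of the linear-projection variant (the repo's actual `NeuroAdapter` in `brain_adapter/model.py` is more involved): it projects a vector of top-k parcel activations into a sequence of conditioning tokens, using the same default sizes as the training flags below (`topk=100`, `num_decoder_queries=50`, `condition_dim=768`).

```python
import numpy as np

class LinearBrainAdapter:
    """Illustrative sketch only: map top-k fMRI parcel activations to a
    sequence of conditioning tokens for a diffusion model. Names and
    dimensions are assumptions, not the repo's actual API."""

    def __init__(self, topk=100, num_tokens=50, condition_dim=768, seed=0):
        rng = np.random.default_rng(seed)
        # One weight matrix from parcel space to all token embeddings at once.
        self.W = rng.standard_normal((topk, num_tokens * condition_dim)) * 0.02
        self.num_tokens, self.condition_dim = num_tokens, condition_dim

    def __call__(self, parcels):
        # parcels: (batch, topk) -> tokens: (batch, num_tokens, condition_dim)
        tokens = parcels @ self.W
        return tokens.reshape(-1, self.num_tokens, self.condition_dim)

adapter = LinearBrainAdapter()
tokens = adapter(np.zeros((1, 100)))  # one fMRI sample -> (1, 50, 768) tokens
```

The `transformer_decoder` variant replaces this single projection with learned queries cross-attending to the parcel features; the token sequence it emits plays the same conditioning role.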
```
brain_decoding/
├── README.md                  # This file
├── train_brain_adapter.py     # Main training script
├── decode_brain_adapter.py    # Inference and decoding script
├── plot_brain_adapter.py      # Visualization utilities
├── metric_brain_adapter.py    # Evaluation metrics
├── acc_config.yaml            # Accelerate configuration
├── brain_adapter/             # Core implementation
│   ├── model.py               # NeuroAdapter and GuidanceGenerator
│   ├── dataset.py             # NSD dataset handling
│   ├── utils.py               # Utility functions
│   └── ip_adapter/            # IP-Adapter components
└── whole_brain_encoder/       # Brain encoder for evaluation
```
```bash
conda env create -f env.yml
conda activate brain
```

- `git clone` the `whole_brain_encoder` repository
- create an `__init__.py` in `whole_brain_encoder`
- replace `import xxx` with `import whole_brain_encoder.xxx` in `whole_brain_encoder`
- create a folder called `checkpoints`
- go to `/engram/nklab/algonauts/ethan/whole_brain_encoder/checkpoints`
- organize your folders and copy & paste everything from `/nsd_test/dinov2_q_transformer/schaefer/*`
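The import-rewriting step above can be automated. This helper is illustrative and not part of the repo; `LOCAL_MODULES` is a placeholder you would fill with the module names actually present in your `whole_brain_encoder` checkout:

```python
import re
from pathlib import Path

# Hypothetical module names; replace with the real local modules
# in your whole_brain_encoder checkout.
LOCAL_MODULES = ["models", "utils", "engine"]

def rewrite_imports(text, modules):
    """Prefix bare `import x` / `from x import ...` statements for the
    given local modules with the `whole_brain_encoder.` package path."""
    for mod in modules:
        text = re.sub(rf"^import {mod}\b",
                      f"import whole_brain_encoder.{mod}", text, flags=re.M)
        text = re.sub(rf"^from {mod}\b",
                      f"from whole_brain_encoder.{mod}", text, flags=re.M)
    return text

def rewrite_tree(root="whole_brain_encoder"):
    # Apply the rewrite in place to every .py file under the package.
    for path in Path(root).rglob("*.py"):
        path.write_text(rewrite_imports(path.read_text(), LOCAL_MODULES))
```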
Train a brain adapter model using fMRI data:
```bash
# Basic training with linear projection
accelerate launch --config_file acc_config.yaml train_brain_adapter.py \
  --learning_rate 1e-04 \
  --num_train_epochs 100 \
  --train_batch_size 8 \
  --subject_id 1 \
  --topk 100 \
  --condition_dim 768 \
  --num_decoder_queries 50 \
  --sub_approach linear_projection \
  --wandb
```
```bash
# Advanced training with transformer decoder
accelerate launch --config_file acc_config.yaml train_brain_adapter.py \
  --learning_rate 1e-04 \
  --num_train_epochs 100 \
  --train_batch_size 16 \
  --subject_id 1 \
  --topk 100 \
  --condition_dim 768 \
  --num_decoder_queries 50 \
  --sub_approach transformer_decoder \
  --wandb
```

Generate images from trained models:
```bash
# Decode subset of test images
python decode_brain_adapter.py \
  --model_weights_dir brain_adapter/model_weights/07_26_2025-22_29 \
  --saved_epochs 100 \
  --start_idx 0 \
  --end_idx 16 \
  --save_all_candidates \
  --num_predictions 8 \
  --subject_id 1
```
```bash
# Decode full test dataset
python decode_brain_adapter.py \
  --model_weights_dir brain_adapter/model_weights/07_26_2025-22_29 \
  --saved_epochs 100 \
  --eval_full_dataset \
  --batch_size 1 \
  --subject_id 1
```

Compute comprehensive quality metrics:
```bash
# Evaluate subset results
python metric_brain_adapter.py \
  --model_weights_dir brain_adapter/model_weights/07_26_2025-22_29 \
  --saved_epochs 100 \
  --evaluation_mode subset
```
```bash
# Evaluate full dataset results
python metric_brain_adapter.py \
  --model_weights_dir brain_adapter/model_weights/07_26_2025-22_29 \
  --saved_epochs 100 \
  --evaluation_mode full
```

Generate comparison plots:
```bash
# Plot all samples
python plot_brain_adapter.py \
  brain_adapter/decoded_stimuli/07_26_2025-22_29/epoch_100/subset
```
```bash
# Plot specific range
python plot_brain_adapter.py \
  brain_adapter/decoded_stimuli/07_26_2025-22_29/epoch_100/subset \
  --start_idx 0 --end_idx 10 \
  --plot_type individual
```

Key parameters:

- `--topk`: number of brain parcels to use (default: 100)
- `--condition_dim`: conditioning token dimension (default: 768)
- `--sub_approach`: architecture type (`linear_projection` or `transformer_decoder`)
- `--num_decoder_queries`: number of transformer queries (default: 50)
- `--learning_rate`: training learning rate (default: 1e-4)
- `--subject_id`: NSD subject ID (1-8, default: 1)
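For reference, flags like these could be declared with `argparse` roughly as follows. This is an illustrative sketch only; `train_brain_adapter.py` may define them differently:

```python
import argparse

def build_parser():
    # Hypothetical declaration mirroring the documented defaults.
    p = argparse.ArgumentParser(description="Brain adapter training")
    p.add_argument("--topk", type=int, default=100,
                   help="number of brain parcels to use")
    p.add_argument("--condition_dim", type=int, default=768,
                   help="conditioning token dimension")
    p.add_argument("--sub_approach", default="linear_projection",
                   choices=["linear_projection", "transformer_decoder"])
    p.add_argument("--num_decoder_queries", type=int, default=50)
    p.add_argument("--learning_rate", type=float, default=1e-4)
    p.add_argument("--subject_id", type=int, default=1, choices=range(1, 9),
                   help="NSD subject ID (1-8)")
    return p

args = build_parser().parse_args(["--subject_id", "3",
                                  "--sub_approach", "transformer_decoder"])
print(args.subject_id, args.sub_approach)  # 3 transformer_decoder
```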
```
brain_adapter/
├── model_weights/
│   └── MM_DD_YYYY-HH_MM/
│       └── checkpoint-{epoch}/
├── decoded_stimuli/
│   └── model_name/
│       └── epoch_100/
│           ├── subset/
│           │   ├── sample_000001.npz
│           │   ├── evaluation_metadata.json
│           │   └── metric_subset.json
│           └── full/
└── plotted_stimuli/
    └── model_name/
        └── epoch_100/
            └── subset/
                ├── individual/
                ├── grid_comparison.png
                └── plot_summary.json
```
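A decoded sample and its metrics can be inspected with a few lines of Python. The `.npz` key names are not documented here, so this illustrative helper simply reports whatever arrays the file contains:

```python
import json
import numpy as np
from pathlib import Path

def load_decoded_sample(subset_dir):
    """Sketch of reading one decoded sample plus its subset metrics.
    Inspect `sample.files` on a real file to see what the decoder
    actually stored; the structure here is an assumption."""
    subset_dir = Path(subset_dir)
    sample = np.load(subset_dir / "sample_000001.npz")
    metrics = json.loads((subset_dir / "metric_subset.json").read_text())
    # Map each stored array name to its shape for a quick overview.
    return {name: sample[name].shape for name in sample.files}, metrics
```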
This project is licensed under the MIT License.