BOLD-Cast is a two-stage autoregressive deep learning framework for reconstructing long-range brain dynamics from short fMRI scans.
This repository contains the official PyTorch implementation of the paper:
Modeling individual-level long-range brain dynamics from short fMRI scans
- Overview
- Repository Structure
- System Requirements
- Installation Guide
- Instructions for Use
- Reproducing the Main Results
- Data Availability
- Code Availability
- Contact
BOLD-Cast addresses a practical challenge in neuroimaging: how to obtain reliable long-range functional brain dynamics from short resting-state fMRI scans.
The framework consists of two stages:
- Stage I: a graph-based disentanglement module that separates each fMRI sample into:
  - a cohort-invariant embedding
  - a subject-specific embedding
- Stage II: a prompt-based autoregressive forecasting module that combines:
  - graph-derived embeddings
  - raw sequence features
  - timestamp-based prompt embeddings

  and predicts future fMRI signals using a frozen language-model backbone.
The model is trained on parcel-level resting-state fMRI signals parcellated using the Craddock CC200 atlas.
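To make the autoregressive idea concrete, here is a minimal sketch of how a short scan could be extended by feeding each forecast back in as context. It is not the BOLD-Cast model: `predict_next` is a hypothetical stand-in for the trained Stage II forecaster, the toy `persistence_predictor` simply repeats the last frame, and the 200-parcel shape and all lengths are assumed values for illustration.

```python
import numpy as np


def autoregressive_rollout(history, predict_next, n_steps, context_len):
    """Extend `history` (time x parcels) by feeding predictions back as input.

    `predict_next` is a hypothetical stand-in for a trained forecaster: it maps
    a (context_len, parcels) window to the next (seg_len, parcels) segment.
    """
    series = np.asarray(history, dtype=float)
    for _ in range(n_steps):
        context = series[-context_len:]        # most recent window
        next_segment = predict_next(context)   # forecast the next segment
        series = np.concatenate([series, next_segment], axis=0)
    return series


def persistence_predictor(context, seg_len=5):
    # Toy predictor (assumption): repeat the last observed frame seg_len times.
    return np.repeat(context[-1:], seg_len, axis=0)


# Assumed toy input: 40 time points, 200 parcels.
short_scan = np.random.default_rng(0).standard_normal((40, 200))
long_series = autoregressive_rollout(
    short_scan, persistence_predictor, n_steps=4, context_len=20
)
print(long_series.shape)  # (60, 200): 40 observed + 4 segments of 5 frames
```

In the real pipeline the predictor would be replaced by the Stage II module; the loop structure is the only part this sketch is meant to illustrate.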
BOLD-Cast/
├── Stage I/ # Graph-based disentanglement learning
├── Stage II/ # Autoregressive forecasting module
├── dataset/ # User-prepared input data
├── requirements.txt # Python dependencies
├── README.md # Project documentation
└── [other files]
For small-scale testing or toy examples, BOLD-Cast can be run on a standard workstation.
For full training and reproduction of the main experiments in the manuscript, we recommend:
- NVIDIA GPU with at least 16–32 GB memory
- Sufficient CPU RAM for loading preprocessed parcel-level fMRI data
- Linux/Windows-based environment for large-scale training
The main experiments in the manuscript were run on:
- NVIDIA V100 GPUs (32 GB memory)
The software environment for the main experiments was:
- Windows 11
- Python 3.10
- CUDA 11.8
- PyTorch 2.0.1
Main dependencies are listed in requirements.txt. Typical packages include:
torch
numpy
scipy
scikit-learn
h5py
pandas
matplotlib
transformers
Install from GitHub
git clone https://github.com/CUHK-AIM-Group/BOLD-Cast.git
cd BOLD-Cast
pip install -r requirements.txt
Typical installation time: on a normal desktop or workstation with a stable internet connection, installation typically takes 10–20 minutes.
Downloading pretrained language model checkpoints may require additional time depending on network conditions.
BOLD-Cast does not operate directly on raw DICOM or raw NIfTI files.
The expected input to the model is preprocessed parcel-level fMRI time series, obtained after:
- standard fMRI preprocessing
- parcellation using the Craddock CC200 atlas
- construction of sliding-window sequences
For each sliding window:
- parcel-wise time series are used as node features
- functional connectivity (FC) is computed from Pearson correlation
- graph representations are then constructed for Stage I
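The windowing and FC steps above can be sketched with NumPy as follows. This is an illustrative sketch, not the repository's `prepare_data.py`: the function names, the window length of 30, the stride of 10, and the 200-parcel input are all assumptions chosen for the example.

```python
import numpy as np


def sliding_windows(ts, window, stride):
    """Split a (time, parcels) array into overlapping (window, parcels) slices."""
    starts = range(0, ts.shape[0] - window + 1, stride)
    return np.stack([ts[s:s + window] for s in starts])


def window_fc(win):
    """Pearson correlation between parcels for one (window, parcels) slice."""
    # rowvar=False treats each column (parcel) as a variable.
    return np.corrcoef(win, rowvar=False)


# Assumed toy input: 100 time points, 200 parcels of random signal.
rng = np.random.default_rng(0)
ts = rng.standard_normal((100, 200))

windows = sliding_windows(ts, window=30, stride=10)  # node features per window
fc = np.stack([window_fc(w) for w in windows])       # one FC matrix per window
print(windows.shape, fc.shape)  # (8, 30, 200) (8, 200, 200)
```

Each FC matrix is symmetric with a unit diagonal, so a graph builder would typically keep only the off-diagonal entries as edge weights.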
To use this repository on your own data, you should:
- Obtain approved access to the original neuroimaging dataset.
- Preprocess the fMRI data into parcel-level time series.
- Organize the files in a format compatible with the repository.
- Run Stage I to extract disentangled graph embeddings.
- Run Stage II to forecast future fMRI dynamics.
For Stage I
- Put the datasets under the folder `./dataset/`.
- Important args:
  - `--use_pretrain`: test checkpoints in `checkpoints`
  - `--dataset`: `ukb`, `hcp-d`, `hcp-ya`, `hcp-a`, `abide`
  - `--custom_key`: `Node` (node classification)
- Training
  - `python prepare_data.py`
  - `python main.py`
- Testing
  - `use_pretrain == 'True'`
For Stage II
- Put the datasets under the folder `./dataset/`.
- Download the large language model from Hugging Face. The default LLM is GPT-2 (https://huggingface.co/openai-community/gpt2).
  For example, once the `gpt2` directory has been downloaded and placed correctly, the directory structure is as follows:
  - data_provider
  - dataset
  - gpt2
    - config.json
    - flax_model.msgpack
    - generation_config.json
    - ...
  - ...
  - run.py
- Generate timestamps with GPT-2, saved as files suffixed by `{subjectid}.pt`: `python generate_timestamp_ukb.py`
- Generate timestamp embeddings and save them, along with the historical time series, common features, and characteristic features, as H5 data in `dataset/ukb_input`. The training, validation, and test datasets need to be generated in three separate batches: `python preprocess_ts.py`
- Training: `python run.py`
- Testing
  - `use_pretrain == 'True'`
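Before launching Stage II, it can help to confirm that the language-model directory sits where the layout above expects it. The helper below is a hypothetical convenience, not part of the repository; the function name and the choice of `config.json` as the only required file are assumptions.

```python
from pathlib import Path


def missing_llm_files(repo_root, model_dir="gpt2", required=("config.json",)):
    """Return the required files that are absent from <repo_root>/<model_dir>."""
    model_path = Path(repo_root) / model_dir
    return [name for name in required if not (model_path / name).is_file()]


# An empty list means every required file was found under ./gpt2.
print(missing_llm_files("."))
```

Extend `required` with any other checkpoint files your setup depends on.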
Typical forecasting metrics include:
- MAE
- RMSE
- MAPE
- mPCC
- FCPCC
- FN
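As a rough guide to the first four metrics, here is a minimal NumPy sketch. MAE, RMSE, and MAPE follow their standard definitions. The `mean_pcc` function is only an assumed reading of mPCC (Pearson correlation per parcel, averaged over parcels); the manuscript's exact definitions of mPCC, FCPCC, and FN should be taken as authoritative, and the latter two are not sketched here.

```python
import numpy as np


def mae(y_true, y_pred):
    """Mean absolute error."""
    return float(np.mean(np.abs(y_pred - y_true)))


def rmse(y_true, y_pred):
    """Root mean squared error."""
    return float(np.sqrt(np.mean((y_pred - y_true) ** 2)))


def mape(y_true, y_pred, eps=1e-8):
    """Mean absolute percentage error; eps guards against division by zero."""
    return float(np.mean(np.abs(y_pred - y_true) / (np.abs(y_true) + eps)) * 100)


def mean_pcc(y_true, y_pred):
    """Assumed reading of mPCC: Pearson r per parcel (column), then averaged."""
    t = y_true - y_true.mean(axis=0)
    p = y_pred - y_pred.mean(axis=0)
    r = (t * p).sum(axis=0) / np.sqrt((t ** 2).sum(axis=0) * (p ** 2).sum(axis=0))
    return float(r.mean())


# Toy (time, parcels) arrays: predictions are the truth shifted by 0.5.
y_true = np.array([[1.0, 2.0], [2.0, 4.0], [3.0, 6.0]])
y_pred = y_true + 0.5
print(mae(y_true, y_pred), rmse(y_true, y_pred))
print(mape(y_true, y_pred), mean_pcc(y_true, y_pred))
```

A constant offset leaves the per-parcel correlation at 1 while MAE and RMSE are nonzero, which is why the error-based and correlation-based metrics are reported together.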
Downstream analyses in the manuscript include:
- ASD classification
- sex identification
- cognitive score prediction
Most of the comparison algorithms have been integrated into the models. Some GNN-based models cannot be integrated and are therefore not included.
The data analyzed in this study are available only for bona fide research purposes and require approval from the corresponding data providers. The datasets used in the manuscript include:
- UK Biobank
- ABIDE
- HCP-Young Adult
- HCP-Development
- HCP-Aging
Due to privacy, ethical, and data-use restrictions, these datasets are not redistributed in this repository.
Users must obtain access directly from the corresponding data providers and prepare the data locally before running the code.
We are grateful to the following GitHub repositories for their valuable code and efforts:
- Time-Series-Library (https://github.com/thuml/Time-Series-Library)
- FPT (https://github.com/DAMO-DI-ML/NeurIPS2023-One-Fits-All)
For questions regarding the code or manuscript, please contact Yu Jiang (yuajiang@cuhk.edu.hk).
If you use this code in your research, please cite the corresponding paper.