If you are interested only in the TissueFormer architecture, see the source code in `model.py`.
This framework addresses a fundamental challenge in neuroscience: accurately mapping the spatial organization of cell types across brain regions. By combining single-cell RNA sequencing data with spatial information, we developed a transformer-based model that groups spatially proximate cells and predicts their anatomical locations within the brain.
- Hierarchical Transformer Architecture: Combines BERT-based gene expression encoding with set transformer layers for spatial group processing
- Spatial Grouping Strategies: Implements both hexagonal grid and k-nearest neighbor sampling for creating spatially coherent cell groups
- Multi-scale Learning: Learns representations at both single-cell and spatial group levels
- Comprehensive Benchmarking: Includes comparisons against Random Forest and Logistic Regression baselines
- Flexible Configuration: Uses Hydra for experiment management and hyperparameter tuning
- Python 3.11+
- CUDA-compatible GPU
- Micromamba or Conda
- Clone the repository:

```bash
git clone <repository-url>
cd brain-annotation
```

- Create and activate the environment:

```bash
# Create environment named 'spatial_transformer' (or your preferred name)
source create_env.sh spatial_transformer
```

This will install all required dependencies, including:
- PyTorch with CUDA support
- Transformers, Datasets (HuggingFace)
- Hydra for configuration management
- Scientific computing libraries (NumPy, SciPy, scikit-learn)
- Visualization tools (Matplotlib, Seaborn)
- Weights & Biases for experiment tracking
The pipeline expects single-cell RNA-seq data in .h5ad (AnnData) format with:
- Gene expression matrix: raw counts in `adata.X`
- Spatial coordinates: 3D coordinates in `adata.obsm['CCF_streamlines']`
- Cell type annotations: in `adata.obs['H3_type']` (optional, for analysis)
- Area annotations: in `adata.obs['CCFano']` (Allen Brain Atlas annotation IDs)
- Convert MATLAB files to H5AD format (if starting from .mat files):

```bash
python data/mat_to_h5.py \
    --input_dir /path/to/mat/files \
    --output_dir /path/to/h5ad/output \
    --force  # Overwrite existing files
```

- Tokenize gene expression data:

```bash
python data/tokenize_cells.py \
    --h5ad_data_directory /path/to/h5ad/files \
    --output_directory /path/to/tokenized/output \
    --output_prefix train_test_barseq \
    --cv-fold 0 \
    --raw-counts
```

Here `--cv-fold` selects the cross-validation fold (0-3 for train, >=4 for the test set), and `--raw-counts` includes raw counts for benchmarking. The tokenization process:
- Normalizes gene expression by total counts per cell
- Ranks genes by expression level
- Converts to token sequences compatible with transformer models
- Adds spatial coordinates and metadata
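In sketch form, this rank-value tokenization (in the style of Geneformer) looks roughly like the following. The vocabulary mapping, truncation length, and normalization here are illustrative, not the actual `tokenize_cells.py` implementation:

```python
import numpy as np

def rank_tokenize(counts, gene_token_ids, max_len=2048):
    """Rank-based tokenization sketch: normalize by total counts,
    rank genes by expression, and emit token IDs in rank order."""
    norm = counts / counts.sum()           # normalize per cell
    order = np.argsort(norm)[::-1]         # highest expression first
    expressed = order[norm[order] > 0][:max_len]  # drop zeros, truncate
    return [gene_token_ids[g] for g in expressed]

# Toy example: 4 genes, hypothetical token IDs offset by 10
counts = np.array([5.0, 0.0, 12.0, 3.0])
tokens = rank_tokenize(counts, gene_token_ids={0: 10, 1: 11, 2: 12, 3: 13})
print(tokens)  # [12, 10, 13]: gene 2 is most expressed, gene 1 is dropped
```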
- Calculate class weights (optional, for imbalanced datasets):

```bash
python data/calculate_class_weights.py \
    data.dataset_path=/path/to/tokenized/dataset \
    data.label_key=area_label \
    weighting.method=balanced
```

The model consists of three main components:
- BERT Encoder: Processes tokenized gene expression for each cell
- Spatial Grouping: Groups nearby cells using configurable strategies
- Set Transformer: Aggregates information from spatial groups
```
Input: Group of spatially proximate cells
        ↓
BERT encoding (per cell)
        ↓
Position encoding (optional)
        ↓
Set Transformer layers
        ↓
Mean pooling
        ↓
Classification head
        ↓
Output: Brain area prediction
```
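The readout at the bottom of the diagram can be illustrated with a minimal numpy sketch. The shapes and the linear head below are placeholders; the real model runs trained set transformer layers between the BERT embeddings and the pooled readout:

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed shapes: a group of 32 cells, each with a 768-d BERT embedding
group_size, hidden, num_areas = 32, 768, 10
cell_embeddings = rng.normal(size=(group_size, hidden))  # stand-in for per-cell BERT outputs

# Set transformer layers would mix information across the group here;
# this sketch skips straight to the permutation-invariant readout.
group_embedding = cell_embeddings.mean(axis=0)           # mean pooling over cells

# Linear classification head (weights are random placeholders)
W = rng.normal(size=(hidden, num_areas))
b = np.zeros(num_areas)
logits = group_embedding @ W + b
predicted_area = int(np.argmax(logits))
print(logits.shape, predicted_area)
```

Because mean pooling is order-invariant, the prediction does not depend on how cells within a group are enumerated.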
- Tessellates the tissue with a hexagonal grid
- Ensures uniform spatial coverage
- Configurable hex size based on cell density
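A standard way to implement this binning is axial hexagon coordinates with cube rounding. The sketch below is illustrative and may differ from the repository's implementation (for instance, it ignores `hex_scaling`):

```python
import numpy as np

def hex_cell(x, y, size):
    """Assign a 2D point to a pointy-top hexagon (axial coordinates)
    of the given size, using standard cube-coordinate rounding."""
    q = (np.sqrt(3) / 3 * x - y / 3) / size
    r = (2 / 3 * y) / size
    s = -q - r                      # cube constraint: q + r + s == 0
    rq, rr, rs = round(q), round(r), round(s)
    dq, dr, ds = abs(rq - q), abs(rr - r), abs(rs - s)
    # Re-derive the coordinate with the largest rounding error
    if dq > dr and dq > ds:
        rq = -rr - rs
    elif dr > ds:
        rr = -rq - rs
    return (rq, rr)

# Points near the origin fall into the same hex; distant points do not
print(hex_cell(0.1, 0.1, size=50))    # (0, 0)
print(hex_cell(200.0, 10.0, size=50)) # a different hex
```

Cells sharing a hex cell form one spatial group, which guarantees non-overlapping, uniform coverage of the tissue.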
- Uses KD-tree for efficient nearest neighbor search
- Adaptively expands search radius when needed
- Suitable for irregular tissue shapes
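A minimal version of this sampling with SciPy's `cKDTree` (the adaptive radius-expansion logic mentioned above is omitted; seeds and sizes are illustrative):

```python
import numpy as np
from scipy.spatial import cKDTree

def knn_groups(coords, seed_idx, group_size):
    """Sketch of k-nearest-neighbor group sampling: for each seed cell,
    take its group_size nearest cells (including itself) via a KD-tree."""
    tree = cKDTree(coords)
    _, idx = tree.query(coords[seed_idx], k=group_size)
    return idx

rng = np.random.default_rng(0)
coords = rng.uniform(size=(1000, 3))            # 3D CCF-like coordinates
seeds = rng.choice(len(coords), size=4, replace=False)
groups = knn_groups(coords, seeds, group_size=32)
print(groups.shape)  # (4, 32): one spatially coherent group per seed
```

Each seed is its own nearest neighbor, so it always appears in its group; unlike hex tessellation, groups from nearby seeds may overlap.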
Train a model with default settings:

```bash
python train.py
```

Customize training using Hydra's override syntax:
```bash
# Train with hexagonal spatial grouping
python train.py \
    data.sampling.strategy=hex \
    data.group_size=32 \
    training.learning_rate=1e-4 \
    training.num_train_epochs=15
```

- `pretrained_type`: Model initialization strategy
  - `"none"`: Train from scratch
  - `"bert_only"`: Use pretrained BERT, train the set transformer from scratch
  - `"full"`: Load the complete pretrained model
  - `"single-cell"`: Single-cell baseline without spatial grouping
- `num_set_layers`: Number of set transformer layers (default: 4)
- `set_hidden_size`: Hidden dimension of the set transformer (default: 768)
- `relative_positions.enabled`: Use relative position encoding
- `dataset_path`: Path to the tokenized dataset
- `group_size`: Number of cells per spatial group (default: 32)
- `sampling.strategy`: `"hex"` or `"random"`
- `sampling.hex_scaling`: Scaling factor for the hexagon size
- `sampling.max_radius_expansions`: Maximum number of search-radius expansions
- `num_train_epochs`: Number of training epochs
- `per_device_train_batch_size`: Batch size (divided by `group_size`)
- `learning_rate`: Learning rate (default: 1e-4)
- `warmup_ratio`: Fraction of steps for learning-rate warmup
For distributed training across multiple GPUs:

```bash
accelerate launch --multi_gpu --num_processes 4 train.py
```

The training script automatically evaluates on the validation and test sets. Results include:
- Per-class precision, recall, and F1 scores
- Confusion matrices
- Spatial distribution of predictions
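Per-class precision, recall, and F1 can all be derived from the confusion matrix; a small numpy illustration with toy labels (not the repository's evaluation code):

```python
import numpy as np

def per_class_prf(y_true, y_pred, num_classes):
    """Per-class precision/recall/F1 from a confusion matrix C,
    where C[i, j] counts true class i predicted as class j."""
    C = np.zeros((num_classes, num_classes), dtype=int)
    np.add.at(C, (y_true, y_pred), 1)          # accumulate counts
    tp = np.diag(C).astype(float)
    precision = tp / np.maximum(C.sum(axis=0), 1)  # tp / predicted positives
    recall = tp / np.maximum(C.sum(axis=1), 1)     # tp / actual positives
    f1 = 2 * precision * recall / np.maximum(precision + recall, 1e-12)
    return precision, recall, f1

y_true = np.array([0, 0, 1, 1])
y_pred = np.array([0, 1, 1, 1])
p, r, f1 = per_class_prf(y_true, y_pred, num_classes=2)
print(p, r, f1)
```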
Compare against Random Forest and Logistic Regression baselines:

```bash
python benchmarks.py \
    run_bulk_expression_rf=true \
    run_bulk_expression_lr=true \
    run_h3type_rf=true \
    run_h3type_lr=true
```

Benchmark features:
- Bulk expression: Average gene expression per spatial group
- H3 type composition: Histogram of cell types per group
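The two feature sets can be sketched as follows (a toy illustration of what `benchmarks.py` plausibly computes, not its actual code):

```python
import numpy as np

def benchmark_features(expr, types, groups, num_types):
    """Sketch of the two baseline feature sets: mean expression per
    spatial group ('bulk expression') and a normalized histogram of
    cell types per group ('H3 type composition')."""
    bulk = np.stack([expr[g].mean(axis=0) for g in groups])
    comp = np.stack([
        np.bincount(types[g], minlength=num_types) / len(g) for g in groups
    ])
    return bulk, comp

expr = np.array([[2., 0.], [4., 2.], [0., 6.]])   # 3 cells x 2 genes
types = np.array([0, 0, 1])                       # toy H3 type labels
groups = [np.array([0, 1]), np.array([1, 2])]     # two spatial groups
bulk, comp = benchmark_features(expr, types, groups, num_types=2)
print(bulk)  # [[3. 1.] [2. 4.]]
print(comp)  # [[1.  0. ] [0.5 0.5]]
```

Both feature matrices have one row per spatial group, so the same group labels used for the transformer can be fed to the Random Forest and Logistic Regression baselines.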
Visualize hexagonal grid sampling:

```bash
python train.py \
    data.sampling.strategy=hex \
    visualize_hex_grid=true
```

Group cells within specific categories (e.g., same animal or cell type):

```bash
python train.py \
    data.sampling.group_within_keys=[animal_name,H3_type]
```

Enable relative position encoding:

```bash
python train.py \
    model.relative_positions.enabled=true \
    model.relative_positions.encoding_type=sinusoidal \
    model.relative_positions.encoding_dim=48
```

Handle imbalanced datasets:

```bash
python train.py \
    data.class_weights.enabled=true \
    data.class_weights.path=/path/to/weights.npy
```

Configure Weights & Biases logging:

```bash
python train.py \
    wandb.project=my_project \
    wandb.entity=my_team \
    wandb.name=experiment_name
```

Training produces several output files:
- `model/`: Trained model checkpoints
- `trainer_state.json`: Training state for resuming
- `all_results.json`: Evaluation metrics
- `test_brain_predictions_cells.npy`: Test-set predictions with metadata
- `hex_grid_sampling_gs_32.png`: Visualization of spatial grouping (if enabled)
- CUDA out of memory:
  - Reduce `per_device_train_batch_size`
  - Decrease `data.group_size`
  - Enable gradient checkpointing
- Slow training:
  - Increase `dataloader_num_workers`
  - Use `data.sampling.strategy=hex` (faster than random)
  - Enable mixed precision: `fp16=true`
- Poor performance:
  - Increase `training.num_train_epochs` (15-20 recommended)
  - Tune `data.group_size` based on cell density
  - Enable class weighting for imbalanced data
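For the class weighting mentioned above, a standard "balanced" scheme (plausibly what `calculate_class_weights.py` computes with `weighting.method=balanced`; the actual implementation may differ) is:

```python
import numpy as np

def balanced_class_weights(labels):
    """'Balanced' weighting: n_samples / (n_classes * count_per_class),
    so rare brain areas receive proportionally larger loss weights."""
    counts = np.bincount(labels)
    return len(labels) / (len(counts) * counts)

labels = np.array([0, 0, 0, 1])          # area 1 is 3x rarer than area 0
weights = balanced_class_weights(labels)
print(weights)  # [0.6666... 2.        ]
# np.save('/path/to/weights.npy', weights)  # then pass via data.class_weights.path
```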
For quick iteration during development:

```bash
python train.py debug=true
```

This limits the dataset size and disables wandb logging.
If you use this code in your research, please cite:

[Citation information to be added]

This project is licensed under the MIT License. See the LICENSE file for details.
This work builds upon:
- Geneformer for gene expression tokenization
- HuggingFace Transformers for model implementations
- Allen Brain Atlas for anatomical reference standards
For questions or issues, please open a GitHub issue or contact the maintainers.