|
1 | | -# MapPFN: Learning Causal Perturbation Maps in Context |
| 1 | +<h1 align="center">MapPFN: Learning Causal Perturbation Maps in Context</h1> |
2 | 2 |
|
3 | | -This repository contains the code, configurations, and data processing scripts to reproduce the experiments for **MapPFN**, a prior-data fitted network (PFN) that uses in-context learning to predict perturbation effects in unseen biological contexts. |
| 3 | +<p align="center"> |
| 4 | + <a href="https://arxiv.org/abs/2601.21092"><img src="https://img.shields.io/badge/arXiv-b31b1b?style=for-the-badge&logo=arxiv" alt="arXiv"/></a> |
| 5 | + <a href="https://marvinsxtr.github.io/MapPFN"><img src="https://img.shields.io/badge/Project_Page-007ec6?style=for-the-badge&logo=htmx&logoColor=white" alt="Project Page"/></a> |
| 6 | + <a href="https://huggingface.co/marvinsxtr/MapPFN"><img src="https://img.shields.io/badge/Models-f5a623?style=for-the-badge&logo=huggingface&logoColor=white" alt="Models"/></a> |
| 7 | + <a href="https://huggingface.co/datasets/marvinsxtr/MapPFN"><img src="https://img.shields.io/badge/Datasets-f5a623?style=for-the-badge&logo=huggingface&logoColor=white" alt="Datasets"/></a> |
| 8 | +</p> |
4 | 9 |
|
5 | | - |
| 10 | +**MapPFN** is a prior-data fitted network (PFN) that uses in-context learning to predict perturbation effects in unseen biological contexts. |
| 11 | + |
| 12 | +<div align="center"> |
| 13 | + <img src="assets/overview.png" width="80%"> |
| 14 | + <p><em><strong>MapPFN overview.</strong> During pre-training, synthetic causal models are drawn to generate observational and interventional distributions. MapPFN meta-learns to map between pre- and post-perturbation distributions across many causal structures. At inference, it predicts cell-level post-perturbation distributions in one forward pass through amortized inference.</em></p> |
| 15 | +</div> |
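As a rough illustration of the pre-training idea in the caption, the toy sketch below draws one random linear causal model and samples an observational and a knockout (interventional) population from it. It is not the prior used by MapPFN: the function names, the edge probability, the weight range, and the unit Gaussian noise are all assumptions made for this example.

```python
import random


def sample_linear_scm(num_genes, edge_prob=0.3, rng=None):
    """Draw a random linear SCM over a DAG (hypothetical toy prior).

    weights[j][i] is the effect of gene i on gene j and is non-zero only for
    i < j, so the gene order 0..num_genes-1 is a valid topological order.
    """
    rng = rng or random.Random()
    weights = [[0.0] * num_genes for _ in range(num_genes)]
    for j in range(num_genes):
        for i in range(j):
            if rng.random() < edge_prob:
                weights[j][i] = rng.uniform(-1.0, 1.0)
    return weights


def sample_cells(weights, num_cells, knockout=None, rng=None):
    """Sample cells from the SCM; `knockout` clamps one gene to zero."""
    rng = rng or random.Random()
    num_genes = len(weights)
    cells = []
    for _ in range(num_cells):
        x = [0.0] * num_genes
        for j in range(num_genes):  # visit genes in topological order
            if j == knockout:
                x[j] = 0.0  # hard intervention: ignore parents and noise
            else:
                parents = sum(weights[j][i] * x[i] for i in range(j))
                x[j] = parents + rng.gauss(0.0, 1.0)
        cells.append(x)
    return cells


rng = random.Random(0)
weights = sample_linear_scm(num_genes=5, rng=rng)
observational = sample_cells(weights, num_cells=100, rng=rng)
interventional = sample_cells(weights, num_cells=100, knockout=2, rng=rng)
```

Repeating this for many sampled models yields pairs of pre- and post-perturbation populations, which is the kind of data the caption describes the model meta-learning from.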
6 | 16 |
|
7 | 17 | ## Abstract |
8 | 18 |
|
9 | | -Planning effective interventions in biological systems requires treatment-effect models that adapt to unseen biological contexts by identifying their specific underlying mechanisms. Yet single-cell perturbation datasets span only a handful of biological contexts, and existing methods cannot leverage new interventional evidence at inference time to adapt beyond their training data. To meta-learn a perturbation effect estimator, we present MapPFN, a prior-data fitted network (PFN) pretrained on synthetic data generated from a prior over causal perturbations. Given a set of experiments, MapPFN uses in-context learning to predict post-perturbation distributions, without gradient-based optimization. Despite being pretrained on *in silico* gene knockouts alone, MapPFN identifies differentially expressed genes, matching the performance of models trained on real single-cell data. |
| 19 | +Planning effective interventions in biological systems requires treatment-effect models that adapt to unseen biological contexts by identifying their specific underlying mechanisms. Yet single-cell perturbation datasets span only a handful of biological contexts, and existing methods cannot leverage new interventional evidence at inference time to adapt beyond their training data. To meta-learn a perturbation effect estimator, we present MapPFN, a prior-data fitted network (PFN) pre-trained on synthetic data generated from a prior over causal perturbations. Given a set of experiments, MapPFN uses in-context learning to predict post-perturbation distributions. Pre-trained on *in silico* gene knockouts alone, MapPFN identifies differentially expressed genes on par with models trained on real single-cell data. Fine-tuned, it consistently outperforms all baselines across downstream datasets. |
10 | 20 |
|
11 | | -## Table of Contents |
| 21 | +## Setup |
12 | 22 |
|
13 | | -- [Setup](#setup) |
14 | | -- [Repository Structure](#repository-structure) |
15 | | -- [Usage](#usage) |
16 | | - - [Data Generation](#data-generation) |
17 | | - - [Training](#training) |
18 | | -- [Dependencies](#dependencies) |
| 23 | +A Docker image with all dependencies is provided, together with a devcontainer configuration: |
19 | 24 |
|
20 | | -## Setup |
| 25 | +```bash |
| 26 | +docker run --rm -it --gpus all -v "$(pwd)":/srv/repo ghcr.io/marvinsxtr/mappfn:latest bash |
| 27 | +``` |
| 28 | + |
| 29 | +<details> |
| 30 | +<summary>VS Code & Slurm</summary> |
21 | 31 |
|
22 | | -A `Dockerfile` is provided for containerized environments. The image includes all dependencies and can be used with Docker or Apptainer on HPC clusters. |
| 32 | +Use the [Remote - Containers](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers) extension to open the devcontainer locally, or connect through a remote tunnel by replacing `bash` with `code tunnel` in the `docker run` command above. |
23 | 33 |
|
24 | | -Logging to WandB is optional for local jobs but mandatory for jobs submitted to the cluster. Create a `.env` file in the root of the repository with: |
| 34 | +This setup also works with Apptainer on Slurm clusters. See the [ml-project-template](https://github.com/marvinsxtr/ml-project-template) for instructions. |
| 35 | + |
| 36 | +</details> |
| 37 | + |
| 38 | +<details> |
| 39 | +<summary>WandB logging (optional)</summary> |
| 40 | + |
| 41 | +Create a `.env` file in the repository root: |
25 | 42 |
|
26 | 43 | ```bash |
27 | 44 | WANDB_API_KEY=your_api_key |
28 | 45 | WANDB_ENTITY=your_entity |
29 | 46 | WANDB_PROJECT=your_project_name |
30 | 47 | ``` |
31 | 48 |
|
32 | | -## Repository Structure |
| 49 | +</details> |
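Before submitting a job, it can help to confirm that the `.env` file actually defines the three WandB variables. The sketch below is a stand-alone check using only the standard library; it does not reflect how the repository itself loads `.env`, and the helper names `read_env_file` and `missing_wandb_keys` are hypothetical.

```python
from pathlib import Path

# Keys taken from the .env example in this README.
REQUIRED_KEYS = ("WANDB_API_KEY", "WANDB_ENTITY", "WANDB_PROJECT")


def read_env_file(path):
    """Parse simple KEY=value lines, skipping blank lines and # comments."""
    values = {}
    for line in Path(path).read_text().splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        values[key.strip()] = value.strip()
    return values


def missing_wandb_keys(path=".env"):
    """Return the required keys that are absent or empty in the file."""
    env_path = Path(path)
    values = read_env_file(env_path) if env_path.exists() else {}
    return [key for key in REQUIRED_KEYS if not values.get(key)]
```

Calling `missing_wandb_keys()` from the repository root returns an empty list when all three variables are set.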
33 | 50 |
|
34 | | -``` |
35 | | -MapPFN/ |
36 | | -├── map_pfn/ |
37 | | -│ ├── configs/ # Hydra-zen configuration files |
38 | | -│ ├── data/ # Dataset classes and data generation |
39 | | -│ │ ├── linear_scm.py # Linear SCM data generation |
40 | | -│ │ ├── sergio_dataset.py # SERGIO GRN simulation |
41 | | -│ │ └── perturbation_dataset.py |
42 | | -│ ├── models/ # Model architectures |
43 | | -│ │ ├── map_pfn.py # MapPFN model |
44 | | -│ │ └── mmdit.py # MMDiT architecture |
45 | | -│ ├── eval/ # Evaluation metrics |
46 | | -│ ├── loss/ # Loss functions (CFM) |
47 | | -│ ├── scripts/ # Training and data generation scripts |
48 | | -│ │ ├── train.py |
49 | | -│ │ └── generate_data.py |
50 | | -│ ├── train/ # Training utilities |
51 | | -│ └── utils/ # Helper functions |
52 | | -├── baselines/ |
53 | | -│ ├── condot/ # Conditional Optimal Transport baseline |
54 | | -│ └── metafm/ # Meta Flow Matching baseline |
55 | | -└── datasets/ # Generated datasets (gitignored) |
56 | | -``` |
| 51 | +## Data |
57 | 52 |
|
58 | | -## Usage |
| 53 | +Download pre-trained [weights](https://huggingface.co/marvinsxtr/MapPFN) and [datasets](https://huggingface.co/datasets/marvinsxtr/MapPFN) from Hugging Face: |
59 | 54 |
|
60 | | -### Data Generation |
| 55 | +```python |
| 56 | +from huggingface_hub import hf_hub_download |
| 57 | + |
| 58 | +hf_hub_download("marvinsxtr/MapPFN", "model.ckpt", local_dir="checkpoints", repo_type="model") |
| 59 | +hf_hub_download("marvinsxtr/MapPFN", "frangieh.h5ad", local_dir="datasets/single_cell", repo_type="dataset") |
| 60 | +hf_hub_download("marvinsxtr/MapPFN", "papalexi.h5ad", local_dir="datasets/single_cell", repo_type="dataset") |
| 61 | +hf_hub_download("marvinsxtr/MapPFN", "sergio.h5ad", local_dir="datasets/synthetic", repo_type="dataset") |
| 62 | +``` |
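After downloading, a quick way to confirm that everything landed where the later commands expect it is to check the four target paths. This is a minimal sketch using only the standard library; the paths are derived from the `local_dir` and file name arguments above, and `find_missing` is a hypothetical helper, not part of the repository.

```python
from pathlib import Path

# Relative paths implied by the download calls in this README.
EXPECTED_FILES = (
    "checkpoints/model.ckpt",
    "datasets/single_cell/frangieh.h5ad",
    "datasets/single_cell/papalexi.h5ad",
    "datasets/synthetic/sergio.h5ad",
)


def find_missing(root=".", expected=EXPECTED_FILES):
    """Return the expected files that do not exist under `root`."""
    root = Path(root)
    return [rel for rel in expected if not (root / rel).is_file()]


missing = find_missing()
print("missing files:", missing if missing else "none")
```

Run it from the repository root; any path it prints can be fetched again with the corresponding `hf_hub_download` call.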
61 | 63 |
|
62 | | -Generate synthetic datasets from linear SCMs or biological priors: |
| 64 | +<details> |
| 65 | +<summary>Preprocessing & generation</summary> |
63 | 66 |
|
| 67 | +Preprocess single-cell datasets: |
64 | 68 | ```bash |
65 | | -# Generate linear SCM data |
66 | | -python map_pfn/scripts/generate_data.py cfg=linear_scm |
| 69 | +python map_pfn/scripts/process_sc_data.py |
| 70 | +``` |
67 | 71 |
|
68 | | -# Generate SERGIO GRN data |
69 | | -python map_pfn/scripts/generate_data.py cfg=sergio_grn |
| 72 | +Generate synthetic datasets: |
| 73 | +```bash |
| 74 | +python map_pfn/scripts/generate_data.py cfg=linear # Linear SCMs |
| 75 | +python map_pfn/scripts/generate_data.py cfg=sergio # Biological prior |
70 | 76 | ``` |
71 | 77 |
|
72 | | -### Training |
| 78 | +</details> |
| 79 | + |
| 80 | +## Inference |
| 81 | + |
| 82 | +```python |
| 83 | +from map_pfn.eval.evaluate import load_model |
| 84 | + |
| 85 | +trainer, module, datamodule = load_model( |
| 86 | + method="map_pfn", |
| 87 | + checkpoint_path="checkpoints/model.ckpt", |
| 88 | + dataset_path="datasets/single_cell/frangieh.h5ad", |
| 89 | +) |
| 90 | +preds = trainer.predict(module, datamodule=datamodule) |
| 91 | +``` |
73 | 92 |
|
74 | | -Train MapPFN or baselines using the provided configurations: |
| 93 | +## Fine-tuning |
75 | 94 |
|
| 95 | +Fine-tune from a pre-trained checkpoint: |
76 | 96 | ```bash |
77 | | -# Train MapPFN on linear SCMs |
78 | | -python map_pfn/scripts/train.py cfg=map_pfn_scm |
79 | | -```` |
| 97 | +python map_pfn/scripts/train.py \ |
| 98 | + cfg=map_pfn_rna \ |
| 99 | + cfg/datamodule=frangieh_finetune \ |
| 100 | + cfg.load_checkpoint=checkpoints/model.ckpt \ |
| 101 | + cfg.trainer.val_check_interval=500 \ |
| 102 | + cfg.trainer.callbacks.2.max_steps=3000 \ |
| 103 | + cfg/wandb=base |
| 104 | +``` |
80 | 105 |
|
81 | | -Available model configs: `map_pfn_scm`, `map_pfn_rna`, `condot_scm`, `condot_rna`, `metafm_scm`, `metafm_rna` |
| 106 | +## Pre-training |
82 | 107 |
|
| 108 | +Train MapPFN from scratch: |
83 | 109 | ```bash |
84 | | -# Run distributed sweep on Slurm |
85 | | -python map_pfn/scripts/train.py cfg/job=methods_scm |
| 110 | +python map_pfn/scripts/train.py cfg=map_pfn_rna |
86 | 111 | ``` |
87 | 112 |
|
88 | | -Available sweep configs: `methods_scm`, `methods_sergio`, `map_pfn_scm`, `map_pfn_sergio` |
| 113 | +## Configuration |
89 | 114 |
|
90 | | -See [map_pfn/configs/train/config_stores.py](map_pfn/configs/train/config_stores.py) for all available configurations. This project uses [hydra-zen](https://github.com/mit-ll-responsible-ai/hydra-zen) for configuration management. Override parameters via command line: |
| 115 | +This project uses [hydra-zen](https://github.com/mit-ll-responsible-ai/hydra-zen) for configuration. Display all available options: |
91 | 116 |
|
92 | 117 | ```bash |
93 | | -python map_pfn/scripts/train.py cfg=map_pfn_scm cfg.datamodule.batch_size=64 |
| 118 | +python map_pfn/scripts/train.py --help |
| 119 | +python map_pfn/scripts/generate_data.py --help |
94 | 120 | ``` |
95 | 121 |
|
96 | | -## Dependencies |
| 122 | +## Repository Structure |
| 123 | + |
| 124 | +``` |
| 125 | +MapPFN/ |
| 126 | +├── map_pfn/ |
| 127 | +│ ├── configs/ # Hydra-zen configuration |
| 128 | +│ ├── data/ # Datasets and data generation |
| 129 | +│ ├── models/ # MapPFN and MMDiT architecture |
| 130 | +│ ├── eval/ # Evaluation metrics |
| 131 | +│ ├── loss/ # Loss functions (CFM) |
| 132 | +│ ├── scripts/ # Training and data generation |
| 133 | +│ ├── train/ # Training utilities |
| 134 | +│ └── utils/ # Helpers |
| 135 | +├── baselines/ |
| 136 | +│ ├── condot/ # Conditional Optimal Transport |
| 137 | +│ └── metafm/ # Meta Flow Matching |
| 138 | +└── datasets/ # Generated datasets (gitignored) |
| 139 | +``` |
| 140 | + |
| 141 | +## Citation |
| 142 | + |
| 143 | +```bibtex |
| 144 | +@article{sextro2026mappfn, |
| 145 | + title = {{MapPFN}: Learning Causal Perturbation Maps in Context}, |
| 146 | + author = {Sextro, Marvin and K\l{}os, Weronika and Dernbach, Gabriel}, |
| 147 | + journal = {arXiv preprint arXiv:2601.21092}, |
| 148 | + year = {2026} |
| 149 | +} |
| 150 | +``` |
97 | 151 |
|
98 | | -This project builds on the following open-source libraries: |
| 152 | +## Contributing |
99 | 153 |
|
100 | | -- [JAX](https://github.com/google/jax) - High-performance numerical computing |
101 | | -- [Equinox](https://github.com/patrick-kidger/equinox) - Neural networks in JAX |
102 | | -- [Hydra-zen](https://github.com/mit-ll-responsible-ai/hydra-zen) - Configuration management |
103 | | -- [Diffrax](https://github.com/patrick-kidger/diffrax) - Differential equation solvers in JAX |
104 | | -- [OTT-JAX](https://github.com/ott-jax/ott) - Optimal transport tools |
105 | | -- [AnnData](https://github.com/scverse/anndata) - Annotated data for single-cell analysis |
106 | | -- [Scanpy](https://github.com/scverse/scanpy) - Single-cell analysis in Python |
107 | | -- [Pertpy](https://github.com/theislab/pertpy) - Perturbation analysis tools |
108 | | -- [sergio_rs](https://github.com/rainx0r/sergio_rs) - Single-cell expression simulator |
109 | | -- [grn-paper](https://github.com/maguirre1/grn-paper) - Gene regulatory network sampling |
| 154 | +If you have any feedback, questions, or ideas, please [open an issue](https://github.com/marvinsxtr/MapPFN/issues) or reach out via [email](mailto:m.kleine.sextro@tu-berlin.de). |