Commit 0ee6eb1

Update code
1 parent f7bf72f commit 0ee6eb1

23 files changed

Lines changed: 1870 additions & 1592 deletions

README.md

Lines changed: 113 additions & 68 deletions
@@ -1,109 +1,154 @@
-# MapPFN: Learning Causal Perturbation Maps in Context
+<h1 align="center">MapPFN: Learning Causal Perturbation Maps in Context</h1>

-This repository contains the code, configurations, and data processing scripts to reproduce the experiments for **MapPFN**, a prior-data fitted network (PFN) that uses in-context learning to predict perturbation effects in unseen biological contexts.
+<p align="center">
+<a href="https://arxiv.org/abs/2601.21092"><img src="https://img.shields.io/badge/arXiv-b31b1b?style=for-the-badge&logo=arxiv" alt="arXiv"/></a>
+<a href="https://marvinsxtr.github.io/MapPFN"><img src="https://img.shields.io/badge/Project_Page-007ec6?style=for-the-badge&logo=htmx&logoColor=white" alt="Project Page"/></a>
+<a href="https://huggingface.co/marvinsxtr/MapPFN"><img src="https://img.shields.io/badge/Models-f5a623?style=for-the-badge&logo=huggingface&logoColor=white" alt="Models"/></a>
+<a href="https://huggingface.co/datasets/marvinsxtr/MapPFN"><img src="https://img.shields.io/badge/Datasets-f5a623?style=for-the-badge&logo=huggingface&logoColor=white" alt="Datasets"/></a>
+</p>

-![MapPFN Overview](assets/overview.png)
+**MapPFN** is a prior-data fitted network (PFN) that uses in-context learning to predict perturbation effects in unseen biological contexts.
+
+<div align="center">
+<img src="assets/overview.png" width="80%">
+<p><em><strong>MapPFN overview.</strong> During pre-training, synthetic causal models are drawn to generate observational and interventional distributions. MapPFN meta-learns to map between pre- and post-perturbation distributions across many causal structures. At inference, it predicts cell-level post-perturbation distributions in one forward pass through amortized inference.</em></p>
+</div>

 ## Abstract

-Planning effective interventions in biological systems requires treatment-effect models that adapt to unseen biological contexts by identifying their specific underlying mechanisms. Yet single-cell perturbation datasets span only a handful of biological contexts, and existing methods cannot leverage new interventional evidence at inference time to adapt beyond their training data. To meta-learn a perturbation effect estimator, we present MapPFN, a prior-data fitted network (PFN) pretrained on synthetic data generated from a prior over causal perturbations. Given a set of experiments, MapPFN uses in-context learning to predict post-perturbation distributions, without gradient-based optimization. Despite being pretrained on *in silico* gene knockouts alone, MapPFN identifies differentially expressed genes, matching the performance of models trained on real single-cell data.
+Planning effective interventions in biological systems requires treatment-effect models that adapt to unseen biological contexts by identifying their specific underlying mechanisms. Yet single-cell perturbation datasets span only a handful of biological contexts, and existing methods cannot leverage new interventional evidence at inference time to adapt beyond their training data. To meta-learn a perturbation effect estimator, we present MapPFN, a prior-data fitted network (PFN) pre-trained on synthetic data generated from a prior over causal perturbations. Given a set of experiments, MapPFN uses in-context learning to predict post-perturbation distributions, without gradient-based optimization. Despite being pre-trained on *in silico* gene knockouts alone, MapPFN identifies differentially expressed genes, matching the performance of models trained on real single-cell data.

-## Table of Contents
+## Setup

-- [Setup](#setup)
-- [Repository Structure](#repository-structure)
-- [Usage](#usage)
-- [Data Generation](#data-generation)
-- [Training](#training)
-- [Dependencies](#dependencies)
+A Docker image and devcontainer configuration are provided with all dependencies:

-## Setup
+```bash
+docker run --rm -it --gpus all -v .:/srv/repo ghcr.io/marvinsxtr/mappfn:latest bash
+```
+
+<details>
+<summary>VSCode & Slurm</summary>

-A `Dockerfile` is provided for containerized environments. The image includes all dependencies and can be used with Docker or Apptainer on HPC clusters.
+Use the [Remote - Containers](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers) extension to open the devcontainer locally, or connect to a remote tunnel by replacing `bash` with `code tunnel`.

-Logging to WandB is optional for local jobs but mandatory for jobs submitted to the cluster. Create a `.env` file in the root of the repository with:
+This setup also works with Apptainer on Slurm clusters. See the [ml-project-template](https://github.com/marvinsxtr/ml-project-template) for instructions.
+
+</details>
+
+<details>
+<summary>WandB logging (optional)</summary>
+
+Create a `.env` file in the repository root:

 ```bash
 WANDB_API_KEY=your_api_key
 WANDB_ENTITY=your_entity
 WANDB_PROJECT=your_project_name
 ```

-## Repository Structure
+</details>

-```
-MapPFN/
-├── map_pfn/
-│ ├── configs/ # Hydra-zen configuration files
-│ ├── data/ # Dataset classes and data generation
-│ │ ├── linear_scm.py # Linear SCM data generation
-│ │ ├── sergio_dataset.py # SERGIO GRN simulation
-│ │ └── perturbation_dataset.py
-│ ├── models/ # Model architectures
-│ │ ├── map_pfn.py # MapPFN model
-│ │ └── mmdit.py # MMDiT architecture
-│ ├── eval/ # Evaluation metrics
-│ ├── loss/ # Loss functions (CFM)
-│ ├── scripts/ # Training and data generation scripts
-│ │ ├── train.py
-│ │ └── generate_data.py
-│ ├── train/ # Training utilities
-│ └── utils/ # Helper functions
-├── baselines/
-│ ├── condot/ # Conditional Optimal Transport baseline
-│ └── metafm/ # Meta Flow Matching baseline
-└── datasets/ # Generated datasets (gitignored)
-```
+## Data

-## Usage
+Download pre-trained [weights](https://huggingface.co/marvinsxtr/MapPFN) and [datasets](https://huggingface.co/datasets/marvinsxtr/MapPFN) from Hugging Face:

-### Data Generation
+```python
+from huggingface_hub import hf_hub_download
+
+hf_hub_download("marvinsxtr/MapPFN", "model.ckpt", local_dir="checkpoints", repo_type="model")
+hf_hub_download("marvinsxtr/MapPFN", "frangieh.h5ad", local_dir="datasets/single_cell", repo_type="dataset")
+hf_hub_download("marvinsxtr/MapPFN", "papalexi.h5ad", local_dir="datasets/single_cell", repo_type="dataset")
+hf_hub_download("marvinsxtr/MapPFN", "sergio.h5ad", local_dir="datasets/synthetic", repo_type="dataset")
+```

-Generate synthetic datasets from linear SCMs or biological priors:
+<details>
+<summary>Preprocessing & generation</summary>

+Preprocess single-cell datasets:
 ```bash
-# Generate linear SCM data
-python map_pfn/scripts/generate_data.py cfg=linear_scm
+python map_pfn/scripts/process_sc_data.py
+```

-# Generate SERGIO GRN data
-python map_pfn/scripts/generate_data.py cfg=sergio_grn
+Generate synthetic datasets:
+```bash
+python map_pfn/scripts/generate_data.py cfg=linear # Linear SCMs
+python map_pfn/scripts/generate_data.py cfg=sergio # Biological prior
 ```

-### Training
+</details>
+
+## Inference
+
+```python
+from map_pfn.eval.evaluate import load_model
+
+trainer, module, datamodule = load_model(
+method="map_pfn",
+checkpoint_path="checkpoints/model.ckpt",
+dataset_path="datasets/single_cell/frangieh.h5ad",
+)
+preds = trainer.predict(module, datamodule=datamodule)
+```

-Train MapPFN or baselines using the provided configurations:
+## Fine-tuning

+Fine-tune from a pre-trained checkpoint:
 ```bash
-# Train MapPFN on linear SCMs
-python map_pfn/scripts/train.py cfg=map_pfn_scm
-````
+python map_pfn/scripts/train.py \
+cfg=map_pfn_rna \
+cfg/datamodule=frangieh_finetune \
+cfg.load_checkpoint=checkpoints/model.ckpt \
+cfg.trainer.val_check_interval=500 \
+cfg.trainer.callbacks.2.max_steps=3000 \
+cfg/wandb=base
+```

-Available model configs: `map_pfn_scm`, `map_pfn_rna`, `condot_scm`, `condot_rna`, `metafm_scm`, `metafm_rna`
+## Pre-training

+Train MapPFN from scratch:
 ```bash
-# Run distributed sweep on Slurm
-python map_pfn/scripts/train.py cfg/job=methods_scm
+python map_pfn/scripts/train.py cfg=map_pfn_rna
 ```

-Available sweep configs: `methods_scm`, `methods_sergio`, `map_pfn_scm`, `map_pfn_sergio`
+## Configuration

-See [map_pfn/configs/train/config_stores.py](map_pfn/configs/train/config_stores.py) for all available configurations. This project uses [hydra-zen](https://github.com/mit-ll-responsible-ai/hydra-zen) for configuration management. Override parameters via command line:
+This project uses [hydra-zen](https://github.com/mit-ll-responsible-ai/hydra-zen) for configuration. Display all available options:

 ```bash
-python map_pfn/scripts/train.py cfg=map_pfn_scm cfg.datamodule.batch_size=64
+python map_pfn/scripts/train.py --help
+python map_pfn/scripts/generate_data.py --help
 ```

-## Dependencies
+## Repository Structure
+
+```
+MapPFN/
+├── map_pfn/
+│ ├── configs/ # Hydra-zen configuration
+│ ├── data/ # Datasets and data generation
+│ ├── models/ # MapPFN and MMDiT architecture
+│ ├── eval/ # Evaluation metrics
+│ ├── loss/ # Loss functions (CFM)
+│ ├── scripts/ # Training and data generation
+│ ├── train/ # Training utilities
+│ └── utils/ # Helpers
+├── baselines/
+│ ├── condot/ # Conditional Optimal Transport
+│ └── metafm/ # Meta Flow Matching
+└── datasets/ # Generated datasets (gitignored)
+```
+
+## Citation
+
+```bibtex
+@article{sextro2026mappfn,
+title = {{MapPFN}: Learning Causal Perturbation Maps in Context},
+author = {Sextro, Marvin and K\l{}os, Weronika and Dernbach, Gabriel},
+journal = {arXiv preprint arXiv:2601.21092},
+year = {2026}
+}
+```

-This project builds on the following open-source libraries:
+## Contributing

-- [JAX](https://github.com/google/jax) - High-performance numerical computing
-- [Equinox](https://github.com/patrick-kidger/equinox) - Neural networks in JAX
-- [Hydra-zen](https://github.com/mit-ll-responsible-ai/hydra-zen) - Configuration management
-- [Diffrax](https://github.com/patrick-kidger/diffrax) - Differential equation solvers in JAX
-- [OTT-JAX](https://github.com/ott-jax/ott) - Optimal transport tools
-- [AnnData](https://github.com/scverse/anndata) - Annotated data for single-cell analysis
-- [Scanpy](https://github.com/scverse/scanpy) - Single-cell analysis in Python
-- [Pertpy](https://github.com/theislab/pertpy) - Perturbation analysis tools
-- [sergio_rs](https://github.com/rainx0r/sergio_rs) - Single-cell expression simulator
-- [grn-paper](https://github.com/maguirre1/grn-paper) - Gene regulatory network sampling
+If you have any feedback, questions, or ideas, please [open an issue](https://github.com/marvinsxtr/MapPFN/issues) or reach out via [email](mailto:m.kleine.sextro@tu-berlin.de).
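
The fine-tuning command in the new README mixes two kinds of overrides: config-group selections such as `cfg/datamodule=frangieh_finetune` and dotted value overrides such as `cfg.trainer.val_check_interval=500`. As a rough sketch of how the same argument list could be assembled programmatically (for example, before handing it to `subprocess.run`), the helper below builds it from two dictionaries. `build_train_command` is a hypothetical name and not part of the repository; the script path and override strings are copied from the README, and the sketch assumes the relative order of the two kinds of overrides does not matter.

```python
import shlex


def build_train_command(
    script: str, groups: dict[str, str], values: dict[str, object]
) -> list[str]:
    """Assemble an argv list from config-group choices and value overrides.

    Hypothetical helper for illustration; it is not part of the repository.
    """
    argv = ["python", script]
    # Config-group selections, e.g. cfg=map_pfn_rna or cfg/datamodule=frangieh_finetune.
    argv += [f"{group}={choice}" for group, choice in groups.items()]
    # Dotted value overrides, e.g. cfg.trainer.val_check_interval=500.
    argv += [f"{key}={value}" for key, value in values.items()]
    return argv


command = build_train_command(
    "map_pfn/scripts/train.py",
    groups={
        "cfg": "map_pfn_rna",
        "cfg/datamodule": "frangieh_finetune",
        "cfg/wandb": "base",
    },
    values={
        "cfg.load_checkpoint": "checkpoints/model.ckpt",
        "cfg.trainer.val_check_interval": 500,
        "cfg.trainer.callbacks.2.max_steps": 3000,
    },
)
# A single shell-safe string, handy for logging the exact invocation.
command_line = shlex.join(command)
```

Keeping the overrides in dictionaries makes it easier to vary one value (say, the checkpoint path) across several runs without editing a long shell line by hand.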

baselines/condot/condot_module.py

Lines changed: 5 additions & 1 deletion
@@ -327,4 +327,8 @@ def test_step(self, batch, batch_idx) -> dict[str, np.ndarray]:
 BatchKeys.TREATMENT: batch[BatchKeys.TREATMENT].detach().cpu().numpy().squeeze(1),
 BatchKeys.CONTEXT_ID: np.asarray(batch[BatchKeys.CONTEXT_ID]),
 BatchKeys.TREATMENT_ID: np.asarray(batch[BatchKeys.TREATMENT_ID]),
-}
+}
+
+def predict_step(self, batch, batch_idx):
+"""Transport samples and return predictions with metadata."""
+return self.test_step(batch, batch_idx)
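
The added `predict_step` simply delegates to `test_step`. In PyTorch Lightning, `Trainer.predict` calls `predict_step` on each batch and returns the collected outputs, so this delegation is presumably what lets the README's `trainer.predict(...)` call work for the baseline. The toy below illustrates the pattern in plain Python; `ToyModule`, `run_predict`, and the dictionary keys are hypothetical stand-ins, not the repository's classes or `BatchKeys`.

```python
class ToyModule:
    """Hypothetical stand-in for a Lightning module; not the repository's class."""

    def test_step(self, batch: dict, batch_idx: int) -> dict:
        # Placeholder "transport": shift every source value by one.
        return {
            "prediction": [x + 1.0 for x in batch["source"]],
            "context_id": batch["context_id"],
        }

    def predict_step(self, batch: dict, batch_idx: int) -> dict:
        # Same delegation as in the diff: prediction reuses the test-time path.
        return self.test_step(batch, batch_idx)


def run_predict(module: ToyModule, batches: list[dict]) -> list[dict]:
    """Mimic a predict loop: one predict_step call per batch, results in a list."""
    return [module.predict_step(batch, idx) for idx, batch in enumerate(batches)]


batches = [
    {"source": [0.0, 1.0], "context_id": ["a", "a"]},
    {"source": [2.0], "context_id": ["b"]},
]
outputs = run_predict(ToyModule(), batches)
# Flatten the per-batch dictionaries into one dictionary of lists.
merged = {key: [v for out in outputs for v in out[key]] for key in outputs[0]}
```

Because the result is one dictionary per batch, downstream code typically needs a merge step like the last line before computing metrics over the whole dataset.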

baselines/metafm/gnn.py

Lines changed: 10 additions & 8 deletions
@@ -56,8 +56,9 @@ def __post_init__(self):
 input_size, self.D, self.num_hidden_decoder, self.num_layers_decoder
 )

-self.temporal_freqs = (
-torch.arange(1, self.num_temporal_freqs + 1, device="cuda") * torch.pi
+self.register_buffer(
+"temporal_freqs",
+torch.arange(1, self.num_temporal_freqs + 1) * torch.pi,
 )
 else:
 input_size = (
@@ -77,13 +78,14 @@ def __post_init__(self):
 input_size, self.D, self.num_hidden_decoder, self.num_layers_decoder
 )

-self.temporal_freqs = (
-torch.arange(1, self.num_temporal_freqs + 1, device="cuda") * torch.pi
+self.register_buffer(
+"temporal_freqs",
+torch.arange(1, self.num_temporal_freqs + 1) * torch.pi,
 )
-
-self.B = (
-torch.randn((self.D, self.num_spatial_samples), device="cuda")
-* self.spatial_feat_scale
+
+self.register_buffer(
+"B",
+torch.randn((self.D, self.num_spatial_samples)) * self.spatial_feat_scale,
 )

 def embed_source(self, source_samples, cond=None):
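
The change above replaces tensors created with a hard-coded `device="cuda"` by registered buffers. A buffer follows the module through `.to(device)` and, being persistent by default, is saved in `state_dict()`, so the module can be constructed on a CPU-only machine. A minimal sketch of the pattern follows; `FourierTimeEmbedding` and its `forward` are hypothetical and only the `register_buffer` call mirrors the diff.

```python
import math

import torch
from torch import nn


class FourierTimeEmbedding(nn.Module):
    """Hypothetical module; only the buffer pattern mirrors the diff."""

    def __init__(self, num_temporal_freqs: int = 4) -> None:
        super().__init__()
        # Created on the default device; no device="cuda" argument is needed.
        freqs = torch.arange(1, num_temporal_freqs + 1) * math.pi
        self.register_buffer("temporal_freqs", freqs)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        # t: (batch, 1) -> (batch, 2 * num_temporal_freqs)
        angles = t * self.temporal_freqs
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)


module = FourierTimeEmbedding(num_temporal_freqs=4)
features = module(torch.rand(8, 1))
buffer_names = [name for name, _ in module.named_buffers()]
in_state_dict = "temporal_freqs" in module.state_dict()
num_parameters = sum(p.numel() for p in module.parameters())
```

One consequence worth keeping in mind: a randomly initialized buffer such as `B` is now stored in checkpoints, and checkpoints saved before the change would lack these keys when loaded with strict key matching.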

baselines/metafm/metafm_module.py

Lines changed: 6 additions & 2 deletions
@@ -31,7 +31,7 @@ def __init__(
 num_treat_conditions: int = None,
 num_cell_conditions: int = None,
 base: str = "source",
-integrate_time_steps: int = 500,
+integrate_time_steps: int = 100,
 ):
 super().__init__()
 self.save_hyperparameters()
@@ -383,4 +383,8 @@ def test_step(self, batch, batch_idx) -> dict[str, np.ndarray]:
 BatchKeys.TREATMENT: batch[BatchKeys.TREATMENT].detach().cpu().numpy().squeeze(1),
 BatchKeys.CONTEXT_ID: np.asarray(batch[BatchKeys.CONTEXT_ID]),
 BatchKeys.TREATMENT_ID: np.asarray(batch[BatchKeys.TREATMENT_ID]),
-}
+}
+
+def predict_step(self, batch, batch_idx):
+"""Transport samples and return predictions with metadata."""
+return self.test_step(batch, batch_idx)
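
The first hunk lowers the default `integrate_time_steps` from 500 to 100. The diff does not show how the value is used; assuming it sets the number of steps for integrating the learned flow from t=0 to t=1, fewer steps cut inference cost roughly fivefold at the price of a larger discretization error. The toy below illustrates that trade-off with a fixed-step Euler solver on a velocity field with a known solution. The solver choice and the `decay` field are assumptions for illustration, not the baseline's actual integrator.

```python
import math


def euler_integrate(velocity, x0: float, num_steps: int) -> float:
    """Integrate dx/dt = velocity(t, x) from t=0 to t=1 with fixed-step Euler."""
    x, dt = x0, 1.0 / num_steps
    for step in range(num_steps):
        x += dt * velocity(step * dt, x)
    return x


def decay(t: float, x: float) -> float:
    # Toy velocity field with known solution x(1) = x0 * exp(-1).
    return -x


exact = math.exp(-1.0)
error_100 = abs(euler_integrate(decay, 1.0, 100) - exact)
error_500 = abs(euler_integrate(decay, 1.0, 500) - exact)
```

For a first-order method the error shrinks roughly in proportion to the step size, so whether 100 steps suffice depends on how smooth the learned velocity field is.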
