Upgrading from ESM2-650M to ESM2-3B
This document outlines what it would take to use the larger 3B parameter ESM2 model (esm2_t36_3B_UR50D) instead of the current 650M model (esm2_t33_650M_UR50D).
Good News: Infrastructure Already Supports 3B
The codebase already has the 3B model defined and working! In dnsmex/esm_wrapper.py:95-102:
```python
def esm2_wrapper_of_size(size, device=None):
    name_of_size = {
        "650M": "esm2_t33_650M_UR50D",
        "3B": "esm2_t36_3B_UR50D",  # <-- Already supported!
        "15B": "esm2_t48_15B_UR50D",
    }
    return ESMWrapper(name_of_size[size], device=device)
```

The shanehsazzadeh.ipynb notebook already uses the 3B model for its ESM evaluations.
Model Specifications
| Model | Parameters | Layers | Attention Heads | Hidden Dim |
|---|---|---|---|---|
| ESM2-650M | 650M | 33 | 20 | 1,280 |
| ESM2-3B | 2.8B | 36 | 40 | 2,560 |
| ESM2-15B | 15B | 48 | 40 | 5,120 |
GPU Memory Requirements
ESM2-3B Memory Estimates
- Model parameters in fp32: ~11 GB
- Model parameters in fp16: ~5.6 GB
- Inference memory (including activations): ~12-16 GB VRAM for typical antibody sequences (~120 residues)
- Recommended GPU: NVIDIA A100 (40/80GB) or V100 (32GB)
- Minimum GPU: RTX 3090/4090 (24GB) should work for inference
Comparison to 650M
The 650M model fits comfortably on most modern GPUs (~4GB for inference). The 3B model requires ~4x more memory.
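For reference, the fp32/fp16 estimates above follow directly from the parameter count; this is a back-of-envelope sketch, and actual usage also includes activations and framework overhead:

```python
# Back-of-envelope parameter memory for ESM2-3B (~2.8B parameters),
# matching the fp32/fp16 estimates above. Activations add several GB on top.
n_params = 2.8e9
print(f"fp32: ~{n_params * 4 / 1e9:.1f} GB")  # ~11.2 GB
print(f"fp16: ~{n_params * 2 / 1e9:.1f} GB")  # ~5.6 GB
```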
Files That Need Updating
The following files hardcode "650M" and need to be changed to "3B":
Core Pipeline Files (High Priority)
- dnsmex/magma_helper.py:141 - Main MAGMA analysis pipeline

  ```python
  esm_wrapper = get_cached_esm_wrapper("650M", use_remote=use_remote_esm)  # Change to: "3B"
  ```

- dnsmex/kirby_helper.py:155, 181 - Kirby dataset analysis

  ```python
  esm_wrapper = get_cached_esm_wrapper("650M")
  esm_wrapper = get_cached_esm_wrapper("650M", use_remote=use_remote_esm)
  ```

- scripts/magma_unified_pipeline.py:79 - Unified pipeline script

  ```python
  esm_wrapper = get_cached_esm_wrapper(model_size="650M", use_remote=True)
  ```
Scripts (Medium Priority)
- scripts/magma_validation.py:207-208 - Validation script
- scripts/timing_direct_gpu.py:105, 108 - Timing benchmarks
Default Parameters (Update defaults)
- dnsmex/cached_model_wrappers.py:18, 98 - Default model size
- dnsmex/remote_esm.py:20, 362 - Remote evaluator defaults
Documentation/Data (Update after re-running)
- REPRODUCIBILITY.md - Update memory requirements
- data/whitehead/MAGMA_PIPELINE_STRUCTURE.md - Various CSV/SVG output files will regenerate
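To confirm nothing is missed once the edits are made, a quick scan for remaining hardcoded sizes can help; this is just a sketch, and a plain grep over the same paths works equally well:

```python
from pathlib import Path

# List every Python file under dnsmex/ and scripts/ that still hardcodes "650M",
# so no call site is missed when switching the default to "3B".
for root in ("dnsmex", "scripts"):
    for path in Path(root).rglob("*.py"):
        for lineno, line in enumerate(path.read_text().splitlines(), start=1):
            if '"650M"' in line:
                print(f"{path}:{lineno}: {line.strip()}")
```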
Implementation Steps
Step 1: Verify GPU Resources
Before switching, confirm your remote GPU server (ermine) has sufficient VRAM:
```bash
ssh ermine nvidia-smi
```

You need at least 16GB VRAM, ideally 24GB+ for comfortable margins.
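For a programmatic check on the server itself, something like the following works; this is a sketch and assumes PyTorch with CUDA is available in the server environment:

```python
import torch

# Report free vs. total memory on the visible GPU; ESM2-3B inference wants
# roughly 16 GB free per the estimates above.
free_bytes, total_bytes = torch.cuda.mem_get_info()
print(f"Free: {free_bytes / 1e9:.1f} GB / Total: {total_bytes / 1e9:.1f} GB")
if free_bytes < 16e9:
    print("Warning: less than ~16 GB free; ESM2-3B inference may not fit.")
```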
Step 2: Test 3B Model Locally
Run a quick test to verify the 3B model works:
```python
from dnsmex.esm_wrapper import esm2_wrapper_of_size

# This will download the model on first run (~5.6GB)
esm = esm2_wrapper_of_size("3B", device="cuda")
score = esm.pseudo_perplexity("EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYWMSWVRQAPGKGLEWVANIKQDGSEKYYVDSVKGRFTISRDNAKNSLYLQMNSLRAEDTAVYYCAR")
print(f"Perplexity: {score}")
```

Step 3: Update Default Model Size
The cleanest approach is to update the default in dnsmex/cached_model_wrappers.py:
```python
# dnsmex/cached_model_wrappers.py line 18
def __init__(self, model_size: str = "3B"):  # Changed from "650M"

# dnsmex/cached_model_wrappers.py line 98
def get_cached_esm_wrapper(model_size: str = "3B", use_remote: bool = True):
```

Then update dnsmex/remote_esm.py:

```python
# Line 20
def __init__(self, model_size: str = "3B"):

# Line 362
def get_remote_esm_evaluator(model_size: str = "3B"):
```

Step 4: Clear Cached Results
The existing ESM scores are cached. You'll need to clear them:
```bash
# Remove cached ESM results to force recomputation
rm -rf _ignore/*ESM*
```

Step 5: Re-run Analyses
Re-run the full pipeline to generate new results with ESM2-3B:
```bash
python scripts/magma_unified_pipeline.py
```

Performance Considerations
Timing Estimates (based on shanehsazzadeh.ipynb data)
- ESM2-650M: ~7.6s per 100 sequences on CUDA
- ESM2-3B: Expect ~4x slower, so ~30s per 100 sequences
For the Whitehead dataset (~1000 sequences), expect:
- 650M: ~76 seconds
- 3B: ~5 minutes
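The arithmetic behind these two estimates is simple scaling; the 4x slowdown factor is an assumption until actual 3B benchmarks are run:

```python
# Scale the measured 650M throughput (~7.6 s per 100 sequences) to the
# ~1000-sequence Whitehead dataset, assuming the 3B model is ~4x slower.
per_100_650m = 7.6
n_sequences = 1000
t_650m = per_100_650m * n_sequences / 100
t_3b = 4 * t_650m
print(f"650M: ~{t_650m:.0f} s, 3B: ~{t_3b / 60:.1f} min")  # ~76 s vs ~5 min
```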
Pseudo-perplexity Computation
The masked_logits method masks each position one at a time, so scoring a sequence of length L costs O(L) forward passes. For a 120-residue antibody:
- 650M: ~120 forward passes per sequence
- 3B: Same count, but each pass is slower
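For clarity, here is what that per-position masking loop looks like in outline. This is a minimal sketch using the HuggingFace ESM-2 port rather than the repo's ESMWrapper/masked_logits API, and the details are illustrative:

```python
import torch
from transformers import AutoTokenizer, EsmForMaskedLM

# Minimal sketch of the per-position masking loop behind pseudo-perplexity.
# The point is that a sequence of length L costs L forward passes.
model_name = "facebook/esm2_t36_3B_UR50D"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = EsmForMaskedLM.from_pretrained(model_name).eval().to("cuda")

def pseudo_perplexity(seq: str) -> float:
    enc = tokenizer(seq, return_tensors="pt").to(model.device)
    input_ids = enc["input_ids"]
    total_nll, n_positions = 0.0, 0
    # Skip the special tokens at both ends; mask one residue at a time.
    for pos in range(1, input_ids.shape[1] - 1):
        masked = input_ids.clone()
        true_id = masked[0, pos].item()
        masked[0, pos] = tokenizer.mask_token_id
        with torch.no_grad():
            logits = model(input_ids=masked, attention_mask=enc["attention_mask"]).logits
        log_probs = torch.log_softmax(logits[0, pos], dim=-1)
        total_nll -= log_probs[true_id].item()
        n_positions += 1
    return float(torch.exp(torch.tensor(total_nll / n_positions)))
```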
Alternative: Environment Variable
For flexibility, you could add an environment variable to select model size:
```python
import os

DEFAULT_ESM_SIZE = os.environ.get("ESM_MODEL_SIZE", "3B")
```

Then set ESM_MODEL_SIZE=650M for quick testing or ESM_MODEL_SIZE=3B for production.
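Wired into the existing default, that might look like the following; this is a sketch, where the ESM_MODEL_SIZE variable name is a suggestion and only the function signature comes from the repo:

```python
# dnsmex/cached_model_wrappers.py (sketch)
import os

DEFAULT_ESM_SIZE = os.environ.get("ESM_MODEL_SIZE", "3B")

def get_cached_esm_wrapper(model_size: str = DEFAULT_ESM_SIZE, use_remote: bool = True):
    ...
```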
Minimal Changes Summary
For a quick switch to 3B, edit these 4 files:
- dnsmex/cached_model_wrappers.py - lines 18, 98
- dnsmex/remote_esm.py - lines 20, 362
- dnsmex/magma_helper.py - line 141
- dnsmex/kirby_helper.py - lines 155, 181
Total: 7 one-line changes to switch the default model.
Output Files → Paper Figures/Tables Correspondence
After switching to ESM2-3B and re-running, these outputs need to be regenerated:
Tables
| Paper Reference | Source Script/Notebook | Output File |
|---|---|---|
| Table 2 (Koenig correlations) | notebooks/dasm_paper/koenig.ipynb | _output/correlations_koenig.csv |
| Table 3 (Shane correlations) | notebooks/dasm_paper/shanehsazzadeh.ipynb | _output/correlations_shane.csv |
| Table 4 (MAGMA correlations) | scripts/magma_unified_model_correlation_analysis.py | data/whitehead/processed/magma_correlation_table.tex |
| Table 5 (Timing) | scripts/timing_direct_gpu.py + scripts/make_timing_table.py | data/whitehead/processed/timing_table.tex |
Figures
| Paper Figure | Source Notebook/Script | Output File → FIGURES_DIR |
|---|---|---|
| Fig 1 (NT process in LLMs) | notebooks/dasm_paper/nt_process_in_llms.ipynb | fig1.pdf (composite) |
| Fig 3 (Koenig DASM expression) | notebooks/dasm_paper/koenig.ipynb | koenig_dasm_expression.svg → koenig_expr_compare.pdf |
| Fig S2 (Koenig AbLang SHM colors) | notebooks/dasm_paper/koenig.ipynb | koenig_ablang_shm_colors.svg |
| Fig S3 (Koenig DASM light) | notebooks/dasm_paper/koenig.ipynb | koenig_dasm_expression_light.svg |
| Fig S4 (Koenig neighbor comparison) | notebooks/dasm_paper/koenig.ipynb | koenig_neighbor_comparison.svg |
| Fig S5 (Koenig heavy heatmap) | notebooks/dasm_paper/koenig.ipynb | koenig_heavy_heatmap.svg |
| Fig S6 (PCP predictions) | Snakemake + notebooks/dasm_paper/perplexity.ipynb | fig3.pdf |
| Fig S7 (Shanehsazzadeh) | notebooks/dasm_paper/shanehsazzadeh.ipynb | shanehsazzadeh.svg |
| Fig S8 (MAGMA scatter) | scripts/magma_unified_model_correlation_analysis.py | data/whitehead/processed/magma_unified_model_correlations.svg |
MAGMA Pipeline (Critical for Table 4 & Fig S8)
The MAGMA outputs are generated by running these scripts in order:
```bash
python scripts/magma_unified_pipeline.py                     # Scores with ESM (currently 650M)
python scripts/magma_unified_model_correlation_analysis.py   # Generates table + figure
```

Output files:

- data/whitehead/processed/magma_correlation_table.csv → Table 4 data
- data/whitehead/processed/magma_correlation_table.tex → LaTeX for Table 4
- data/whitehead/processed/magma_unified_model_correlations.svg → Figure S8
Timing Pipeline
Timing benchmarks (Table 5) are generated by:
```bash
python scripts/timing_direct_gpu.py    # Run on GPU server
python scripts/make_timing_table.py    # Generate LaTeX table
```

Output files:

- data/whitehead/processed/direct_timing_results_cuda_1.csv
- data/whitehead/processed/direct_timing_results_cpu.csv
- data/whitehead/processed/timing_table.tex
Note: ESM2-3B will be ~4x slower, so timing table values will change significantly.
Complete Re-generation Checklist
After switching to ESM2-3B:
- Clear ESM caches: rm -rf _ignore/*ESM*
- Re-run notebooks/dasm_paper/koenig.ipynb → Table 2, Figs 3, S2-S5
- Re-run notebooks/dasm_paper/shanehsazzadeh.ipynb → Table 3, Fig S7
- Re-run scripts/magma_unified_pipeline.py → scores unified dataset
- Re-run scripts/magma_unified_model_correlation_analysis.py → Table 4, Fig S8
- Re-run scripts/timing_direct_gpu.py on GPU server → timing data
- Re-run scripts/make_timing_table.py → Table 5
- Convert SVGs to PDFs and copy to FIGURES_DIR (see the sketch below)
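A possible shape for that last conversion step is sketched below; it assumes cairosvg is installed, treats FIGURES_DIR as an environment variable, and uses data/whitehead/processed/ as an illustrative source directory, all of which may differ from the repo's actual setup:

```python
import os
from pathlib import Path

import cairosvg  # assumption: the repo may use a different SVG-to-PDF tool

# Convert every regenerated SVG under data/whitehead/processed/ to PDF and
# place the result in FIGURES_DIR.
figures_dir = Path(os.environ.get("FIGURES_DIR", "figures"))
figures_dir.mkdir(parents=True, exist_ok=True)

for svg in Path("data/whitehead/processed").glob("*.svg"):
    pdf_path = figures_dir / svg.with_suffix(".pdf").name
    cairosvg.svg2pdf(url=str(svg), write_to=str(pdf_path))
    print(f"{svg} -> {pdf_path}")
```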