
Upgrading from ESM2-650M to ESM2-3B #1

@matsen

This document outlines what it would take to use the larger 3B-parameter ESM2 model (esm2_t36_3B_UR50D) in place of the current 650M model (esm2_t33_650M_UR50D).

Good News: Infrastructure Already Supports 3B

The codebase already has the 3B model defined and working! In dnsmex/esm_wrapper.py:95-102:

def esm2_wrapper_of_size(size, device=None):
    name_of_size = {
        "650M": "esm2_t33_650M_UR50D",
        "3B": "esm2_t36_3B_UR50D",      # <-- Already supported!
        "15B": "esm2_t48_15B_UR50D",
    }
    return ESMWrapper(name_of_size[size], device=device)

The shanehsazzadeh.ipynb notebook already uses the 3B model for its ESM evaluations.

Model Specifications

Model      Parameters  Layers  Attention Heads  Hidden Dim
ESM2-650M  650M        33      20               1,280
ESM2-3B    2.8B        36      40               2,560
ESM2-15B   15B         48      40               5,120

GPU Memory Requirements

ESM2-3B Memory Estimates

  • Model parameters in fp32: ~11 GB (a back-of-envelope check follows this list)
  • Model parameters in fp16: ~5.6 GB
  • Inference memory (including activations): ~12-16 GB VRAM for typical antibody sequences (~120 residues)
  • Recommended GPU: NVIDIA A100 (40/80GB) or V100 (32GB)
  • Minimum GPU: RTX 3090/4090 (24GB) should work for inference
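
The parameter figures above are just parameter count times bytes per value; a quick sanity check:

# Weight memory for ESM2-3B (~2.8e9 parameters); activations during
# inference add a few GB on top of these numbers.
n_params = 2.8e9
for dtype, nbytes in [("fp32", 4), ("fp16", 2)]:
    print(f"{dtype}: {n_params * nbytes / 1e9:.1f} GB")  # fp32: 11.2, fp16: 5.6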

Comparison to 650M

The 650M model fits comfortably on most modern GPUs (~4GB for inference). The 3B model requires ~4x more memory.

Files That Need Updating

The following files have hardcoded "650M" that should be changed to "3B":

Core Pipeline Files (High Priority)

  1. dnsmex/magma_helper.py:141 - Main MAGMA analysis pipeline

    esm_wrapper = get_cached_esm_wrapper("650M", use_remote=use_remote_esm)
    # Change to: "3B"
  2. dnsmex/kirby_helper.py:155, 181 - Kirby dataset analysis

    esm_wrapper = get_cached_esm_wrapper("650M")
    esm_wrapper = get_cached_esm_wrapper("650M", use_remote=use_remote_esm)
  3. scripts/magma_unified_pipeline.py:79 - Unified pipeline script

    esm_wrapper = get_cached_esm_wrapper(model_size="650M", use_remote=True)

Scripts (Medium Priority)

  1. scripts/magma_validation.py:207-208 - Validation script
  2. scripts/timing_direct_gpu.py:105, 108 - Timing benchmarks

Default Parameters (Update defaults)

  1. dnsmex/cached_model_wrappers.py:18, 98 - Default model size
  2. dnsmex/remote_esm.py:20, 362 - Remote evaluator defaults

Documentation/Data (Update after re-running)

  1. REPRODUCIBILITY.md - Update memory requirements
  2. data/whitehead/MAGMA_PIPELINE_STRUCTURE.md
  3. Various CSV/SVG output files will regenerate

Implementation Steps

Step 1: Verify GPU Resources

Before switching, confirm your remote GPU server (ermine) has sufficient VRAM:

ssh ermine nvidia-smi

You need at least 16GB VRAM, ideally 24GB+ for comfortable margins.
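
Alternatively, check free VRAM programmatically (a sketch, assuming PyTorch is installed on the server):

import torch

# Report free/total memory on the default CUDA device.
if torch.cuda.is_available():
    free, total = torch.cuda.mem_get_info()
    print(f"GPU VRAM: {free / 1e9:.1f} GB free of {total / 1e9:.1f} GB")
else:
    print("No CUDA device visible")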

Step 2: Test 3B Model Locally

Run a quick test to verify the 3B model works:

from dnsmex.esm_wrapper import esm2_wrapper_of_size

# This will download the model on first run (~5.6GB)
esm = esm2_wrapper_of_size("3B", device="cuda")
score = esm.pseudo_perplexity("EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYWMSWVRQAPGKGLEWVANIKQDGSEKYYVDSVKGRFTISRDNAKNSLYLQMNSLRAEDTAVYYCAR")
print(f"Perplexity: {score}")

Step 3: Update Default Model Size

The cleanest approach is to update the default in dnsmex/cached_model_wrappers.py:

# dnsmex/cached_model_wrappers.py line 18
def __init__(self, model_size: str = "3B"):  # Changed from "650M"

# dnsmex/cached_model_wrappers.py line 98
def get_cached_esm_wrapper(model_size: str = "3B", use_remote: bool = True):

Then update dnsmex/remote_esm.py:

# Line 20
def __init__(self, model_size: str = "3B"):

# Line 362
def get_remote_esm_evaluator(model_size: str = "3B"):

Step 4: Clear Cached Results

The existing ESM scores are cached. You'll need to clear them:

# Remove cached ESM results to force recomputation
rm -rf _ignore/*ESM*

Step 5: Re-run Analyses

Re-run the full pipeline to generate new results with ESM2-3B:

python scripts/magma_unified_pipeline.py

Performance Considerations

Timing Estimates (based on shanehsazzadeh.ipynb data)

  • ESM2-650M: ~7.6s per 100 sequences on CUDA
  • ESM2-3B: Expect ~4x slower, so ~30s per 100 sequences

For the Whitehead dataset (~1000 sequences), expect the following (a measurement sketch follows this list):

  • 650M: ~76 seconds
  • 3B: ~5 minutes
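
These are extrapolations, so before committing to a full re-run it is worth timing a small sample directly (a sketch reusing the wrapper call from Step 2; substitute real sequences):

import time

from dnsmex.esm_wrapper import esm2_wrapper_of_size

# Time pseudo-perplexity on a small sample to project the full-run cost.
esm = esm2_wrapper_of_size("3B", device="cuda")
sample = ["EVQLVESGGGLVQPGGSLRLSCAASGFTFS"] * 10  # placeholder sequences
start = time.perf_counter()
for seq in sample:
    esm.pseudo_perplexity(seq)
per_seq = (time.perf_counter() - start) / len(sample)
print(f"{per_seq:.1f} s/seq -> ~{per_seq * 1000 / 60:.1f} min for 1000 sequences")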

Pseudo-perplexity Computation

The masked_logits method masks each position one at a time, so scoring a sequence of length L requires L forward passes rather than one. For a 120-residue antibody (see the sketch after this list):

  • 650M: ~120 forward passes per sequence
  • 3B: Same count, but each pass is slower
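
For reference, the per-position masking loop looks roughly like this against the public fair-esm API (a minimal sketch; the repo's masked_logits presumably batches the masked copies more efficiently):

import torch
import esm  # fair-esm package

model, alphabet = esm.pretrained.esm2_t36_3B_UR50D()
model = model.eval().cuda()
batch_converter = alphabet.get_batch_converter()

def pseudo_perplexity(seq: str) -> float:
    # The token tensor has a BOS at index 0, so residue i sits at index i.
    _, _, tokens = batch_converter([("seq", seq)])
    tokens = tokens.cuda()
    total_log_prob = 0.0
    for i in range(1, len(seq) + 1):  # one forward pass per residue
        masked = tokens.clone()
        masked[0, i] = alphabet.mask_idx
        with torch.no_grad():
            logits = model(masked)["logits"]
        total_log_prob += torch.log_softmax(logits[0, i], dim=-1)[tokens[0, i]].item()
    return float(torch.exp(torch.tensor(-total_log_prob / len(seq))))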

Alternative: Environment Variable

For flexibility, you could add an environment variable to select model size:

import os
DEFAULT_ESM_SIZE = os.environ.get("ESM_MODEL_SIZE", "3B")

Then set ESM_MODEL_SIZE=650M for quick testing or ESM_MODEL_SIZE=3B for production.
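
Wired into the cached-wrapper defaults, that could look like the following (a sketch; the variable is read once at import time, so set it before launching Python):

import os

# Read once at import; set ESM_MODEL_SIZE before the process starts.
DEFAULT_ESM_SIZE = os.environ.get("ESM_MODEL_SIZE", "3B")

def get_cached_esm_wrapper(model_size: str = DEFAULT_ESM_SIZE, use_remote: bool = True):
    ...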

Minimal Changes Summary

For a quick switch to 3B, edit these 4 files:

  1. dnsmex/cached_model_wrappers.py - lines 18, 98
  2. dnsmex/remote_esm.py - lines 20, 362
  3. dnsmex/magma_helper.py - line 141
  4. dnsmex/kirby_helper.py - lines 155, 181

Total: seven one-line changes to switch the default model.
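
To confirm nothing was missed, a quick scan for any remaining hardcoded size strings (a sketch; run from the repo root):

from pathlib import Path

# Flag any Python source that still pins the 650M model.
for path in Path(".").rglob("*.py"):
    for n, line in enumerate(path.read_text(errors="ignore").splitlines(), 1):
        if '"650M"' in line:
            print(f"{path}:{n}: {line.strip()}")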

Output Files → Paper Figures/Tables Correspondence

After switching to ESM2-3B and re-running, these outputs need to be regenerated:

Tables

Paper Reference                Source Script/Notebook                                        Output File
Table 2 (Koenig correlations)  notebooks/dasm_paper/koenig.ipynb                             _output/correlations_koenig.csv
Table 3 (Shane correlations)   notebooks/dasm_paper/shanehsazzadeh.ipynb                     _output/correlations_shane.csv
Table 4 (MAGMA correlations)   scripts/magma_unified_model_correlation_analysis.py          data/whitehead/processed/magma_correlation_table.tex
Table 5 (Timing)               scripts/timing_direct_gpu.py + scripts/make_timing_table.py  data/whitehead/processed/timing_table.tex

Figures

Paper Figure                         Source Notebook/Script                                Output File → FIGURES_DIR
Fig 1 (NT process in LLMs)           notebooks/dasm_paper/nt_process_in_llms.ipynb         fig1.pdf (composite)
Fig 3 (Koenig DASM expression)       notebooks/dasm_paper/koenig.ipynb                     koenig_dasm_expression.svg → koenig_expr_compare.pdf
Fig S2 (Koenig AbLang SHM colors)    notebooks/dasm_paper/koenig.ipynb                     koenig_ablang_shm_colors.svg
Fig S3 (Koenig DASM light)           notebooks/dasm_paper/koenig.ipynb                     koenig_dasm_expression_light.svg
Fig S4 (Koenig neighbor comparison)  notebooks/dasm_paper/koenig.ipynb                     koenig_neighbor_comparison.svg
Fig S5 (Koenig heavy heatmap)        notebooks/dasm_paper/koenig.ipynb                     koenig_heavy_heatmap.svg
Fig S6 (PCP predictions)             Snakemake + notebooks/dasm_paper/perplexity.ipynb     fig3.pdf
Fig S7 (Shanehsazzadeh)              notebooks/dasm_paper/shanehsazzadeh.ipynb             shanehsazzadeh.svg
Fig S8 (MAGMA scatter)               scripts/magma_unified_model_correlation_analysis.py  data/whitehead/processed/magma_unified_model_correlations.svg

MAGMA Pipeline (Critical for Table 4 & Fig S8)

The MAGMA outputs are generated by running these scripts in order:

python scripts/magma_unified_pipeline.py           # Scores with ESM (currently 650M)
python scripts/magma_unified_model_correlation_analysis.py  # Generates table + figure

Output files:

  • data/whitehead/processed/magma_correlation_table.csv → Table 4 data
  • data/whitehead/processed/magma_correlation_table.tex → LaTeX for Table 4
  • data/whitehead/processed/magma_unified_model_correlations.svg → Figure S8

Timing Pipeline

Timing benchmarks (Table 5) are generated by:

python scripts/timing_direct_gpu.py    # Run on GPU server
python scripts/make_timing_table.py    # Generate LaTeX table

Output files:

  • data/whitehead/processed/direct_timing_results_cuda_1.csv
  • data/whitehead/processed/direct_timing_results_cpu.csv
  • data/whitehead/processed/timing_table.tex

Note: ESM2-3B will be ~4x slower, so timing table values will change significantly.

Complete Re-generation Checklist

After switching to ESM2-3B:

  1. Clear ESM caches: rm -rf _ignore/*ESM*
  2. Re-run notebooks/dasm_paper/koenig.ipynb → Table 2, Figs 3, S2-S5
  3. Re-run notebooks/dasm_paper/shanehsazzadeh.ipynb → Table 3, Fig S7
  4. Re-run scripts/magma_unified_pipeline.py → scores unified dataset
  5. Re-run scripts/magma_unified_model_correlation_analysis.py → Table 4, Fig S8
  6. Re-run scripts/timing_direct_gpu.py on GPU server → timing data
  7. Re-run scripts/make_timing_table.py → Table 5
  8. Convert SVGs to PDFs and copy to FIGURES_DIR (one scripted approach is sketched below)
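
One way to script that final conversion step, assuming the cairosvg package is available (the paths here are hypothetical; point them at the repo's actual output directory and FIGURES_DIR):

from pathlib import Path

import cairosvg

# Hypothetical locations; adjust to the real output and figures directories.
SVG_DIR = Path("data/whitehead/processed")
FIGURES_DIR = Path("figures")

for svg in SVG_DIR.glob("*.svg"):
    pdf = FIGURES_DIR / svg.with_suffix(".pdf").name
    cairosvg.svg2pdf(url=str(svg), write_to=str(pdf))
    print(f"{svg} -> {pdf}")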

