diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6ba59732d7..26c225e337 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,7 +29,33 @@ repos: - id: ruff # 1. Attempt to automatically fix any lint issues. args: ["--fix"] + # Exclude check_copied_files destinations; they are verbatim copies + # of source files and must not be reformatted independently. + exclude: | + (?x)^( + bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv\.py| + bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv\.py| + bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv\.py| + bionemo-recipes/recipes/vllm_inference/esm2/(modeling_esm_te|convert|export|state)\.py| + bionemo-recipes/recipes/(esm2_native_te|llama3_native_te|esm2_peft_te)/collator\.py| + bionemo-recipes/recipes/llama3_native_te/modeling_llama_te\.py| + bionemo-recipes/models/(llama3|mixtral)/collator\.py| + bionemo-recipes/models/(amplify/src/amplify|llama3|mixtral)/state\.py| + bionemo-recipes/models/(llama3|mixtral)/tests/common/ + )$ - id: ruff-format + exclude: | + (?x)^( + bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv\.py| + bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv\.py| + bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv\.py| + bionemo-recipes/recipes/vllm_inference/esm2/(modeling_esm_te|convert|export|state)\.py| + bionemo-recipes/recipes/(esm2_native_te|llama3_native_te|esm2_peft_te)/collator\.py| + bionemo-recipes/recipes/llama3_native_te/modeling_llama_te\.py| + bionemo-recipes/models/(llama3|mixtral)/collator\.py| + bionemo-recipes/models/(amplify/src/amplify|llama3|mixtral)/state\.py| + bionemo-recipes/models/(llama3|mixtral)/tests/common/ + )$ - repo: https://github.com/executablebooks/mdformat rev: 0.7.22 # Use the latest stable version hooks: diff --git a/bionemo-recipes/models/esm2/README.md b/bionemo-recipes/models/esm2/README.md index 
92d976e457..17f5a8b42a 100644 --- a/bionemo-recipes/models/esm2/README.md +++ b/bionemo-recipes/models/esm2/README.md @@ -68,6 +68,8 @@ Training recipes are available in the `bionemo-recipes/recipes/` directory: loop. - **[esm2_accelerate_te](../../recipes/esm2_accelerate_te/)** - Trains the model using HuggingFace [Accelerate](https://huggingface.co/docs/accelerate/index). +- **[vllm_inference/esm2](../../recipes/vllm_inference/esm2/)** - Demonstrates inference with + [vLLM](https://github.com/vllm-project/vllm). ## Converting Between Model Formats diff --git a/bionemo-recipes/models/esm2/convert.py b/bionemo-recipes/models/esm2/convert.py index 6b24fb71e4..da0f650883 100644 --- a/bionemo-recipes/models/esm2/convert.py +++ b/bionemo-recipes/models/esm2/convert.py @@ -27,18 +27,18 @@ mapping = { - "esm.encoder.layer.*.attention.output.dense.weight": "esm.encoder.layers.*.self_attention.proj.weight", - "esm.encoder.layer.*.attention.output.dense.bias": "esm.encoder.layers.*.self_attention.proj.bias", - "esm.encoder.layer.*.attention.LayerNorm.weight": "esm.encoder.layers.*.self_attention.layernorm_qkv.layer_norm_weight", - "esm.encoder.layer.*.attention.LayerNorm.bias": "esm.encoder.layers.*.self_attention.layernorm_qkv.layer_norm_bias", - "esm.encoder.layer.*.intermediate.dense.weight": "esm.encoder.layers.*.layernorm_mlp.fc1_weight", - "esm.encoder.layer.*.intermediate.dense.bias": "esm.encoder.layers.*.layernorm_mlp.fc1_bias", - "esm.encoder.layer.*.output.dense.weight": "esm.encoder.layers.*.layernorm_mlp.fc2_weight", - "esm.encoder.layer.*.output.dense.bias": "esm.encoder.layers.*.layernorm_mlp.fc2_bias", - "esm.encoder.layer.*.LayerNorm.weight": "esm.encoder.layers.*.layernorm_mlp.layer_norm_weight", - "esm.encoder.layer.*.LayerNorm.bias": "esm.encoder.layers.*.layernorm_mlp.layer_norm_bias", - "esm.encoder.emb_layer_norm_after.weight": "esm.encoder.emb_layer_norm_after.weight", - "esm.encoder.emb_layer_norm_after.bias": 
"esm.encoder.emb_layer_norm_after.bias", + "esm.encoder.layer.*.attention.output.dense.weight": "model.encoder.layers.*.self_attention.proj.weight", + "esm.encoder.layer.*.attention.output.dense.bias": "model.encoder.layers.*.self_attention.proj.bias", + "esm.encoder.layer.*.attention.LayerNorm.weight": "model.encoder.layers.*.self_attention.layernorm_qkv.layer_norm_weight", + "esm.encoder.layer.*.attention.LayerNorm.bias": "model.encoder.layers.*.self_attention.layernorm_qkv.layer_norm_bias", + "esm.encoder.layer.*.intermediate.dense.weight": "model.encoder.layers.*.layernorm_mlp.fc1_weight", + "esm.encoder.layer.*.intermediate.dense.bias": "model.encoder.layers.*.layernorm_mlp.fc1_bias", + "esm.encoder.layer.*.output.dense.weight": "model.encoder.layers.*.layernorm_mlp.fc2_weight", + "esm.encoder.layer.*.output.dense.bias": "model.encoder.layers.*.layernorm_mlp.fc2_bias", + "esm.encoder.layer.*.LayerNorm.weight": "model.encoder.layers.*.layernorm_mlp.layer_norm_weight", + "esm.encoder.layer.*.LayerNorm.bias": "model.encoder.layers.*.layernorm_mlp.layer_norm_bias", + "esm.encoder.emb_layer_norm_after.weight": "model.encoder.emb_layer_norm_after.weight", + "esm.encoder.emb_layer_norm_after.bias": "model.encoder.emb_layer_norm_after.bias", "lm_head.dense.weight": "lm_head.dense.weight", "lm_head.dense.bias": "lm_head.dense.bias", "lm_head.layer_norm.weight": "lm_head.decoder.layer_norm_weight", @@ -135,7 +135,7 @@ def convert_esm_te_to_hf(model_te: nn.Module, **config_kwargs) -> nn.Module: "esm.encoder.layer.*.attention.self.key.weight", "esm.encoder.layer.*.attention.self.value.weight", ), - target_key="esm.encoder.layers.*.self_attention.layernorm_qkv.weight", + target_key="model.encoder.layers.*.self_attention.layernorm_qkv.weight", ) def _pack_qkv_weight(ctx: state.TransformCTX, query, key, value): """Pack separate Q, K, V weight tensors into a single interleaved QKV weight tensor.""" @@ -157,7 +157,7 @@ def _pack_qkv_weight(ctx: state.TransformCTX, query, key, 
value): "esm.encoder.layer.*.attention.self.key.bias", "esm.encoder.layer.*.attention.self.value.bias", ), - target_key="esm.encoder.layers.*.self_attention.layernorm_qkv.bias", + target_key="model.encoder.layers.*.self_attention.layernorm_qkv.bias", ) def _pack_qkv_bias(ctx: state.TransformCTX, query, key, value): """Pack separate Q, K, V bias tensors into a single interleaved QKV bias tensor.""" @@ -174,7 +174,7 @@ def _pack_qkv_bias(ctx: state.TransformCTX, query, key, value): @state.state_transform( - source_key="esm.encoder.layers.*.self_attention.layernorm_qkv.weight", + source_key="model.encoder.layers.*.self_attention.layernorm_qkv.weight", target_key=( "esm.encoder.layer.*.attention.self.query.weight", "esm.encoder.layer.*.attention.self.key.weight", @@ -203,7 +203,7 @@ def _unpack_qkv_weight(ctx: state.TransformCTX, qkv_weight): @state.state_transform( - source_key="esm.encoder.layers.*.self_attention.layernorm_qkv.bias", + source_key="model.encoder.layers.*.self_attention.layernorm_qkv.bias", target_key=( "esm.encoder.layer.*.attention.self.query.bias", "esm.encoder.layer.*.attention.self.key.bias", @@ -248,7 +248,7 @@ def _pad_weights(ctx: state.TransformCTX, source_embed): _pad_embeddings = state.state_transform( source_key="esm.embeddings.word_embeddings.weight", - target_key="esm.embeddings.word_embeddings.weight", + target_key="model.embeddings.word_embeddings.weight", )(_pad_weights) _pad_decoder_weights = state.state_transform( @@ -257,7 +257,7 @@ def _pad_weights(ctx: state.TransformCTX, source_embed): )(_pad_weights) _unpad_embeddings = state.state_transform( - source_key="esm.embeddings.word_embeddings.weight", + source_key="model.embeddings.word_embeddings.weight", target_key="esm.embeddings.word_embeddings.weight", )(_unpad_weights) diff --git a/bionemo-recipes/models/esm2/export.py b/bionemo-recipes/models/esm2/export.py index 12f13e45a2..748e467843 100644 --- a/bionemo-recipes/models/esm2/export.py +++ 
b/bionemo-recipes/models/esm2/export.py @@ -71,7 +71,11 @@ def export_hf_checkpoint(tag: str, export_path: Path): model_hf_masked_lm = AutoModelForMaskedLM.from_pretrained(f"facebook/{tag}") model_hf = AutoModel.from_pretrained(f"facebook/{tag}") model_hf_masked_lm.esm.pooler = model_hf.pooler - model_te = convert_esm_hf_to_te(model_hf_masked_lm) + + # Export without vocab padding so the checkpoint stores embeddings at the real + # vocab_size. This avoids shape-mismatch errors in vLLM's VocabParallelEmbedding, + # which expects vocab_size-shaped weights. + model_te = convert_esm_hf_to_te(model_hf_masked_lm, padded_vocab_size=None) model_te.save_pretrained(export_path / tag) tokenizer = AutoTokenizer.from_pretrained("esm_fast_tokenizer") # Use our PreTrainedTokenizerFast implementation. diff --git a/bionemo-recipes/models/esm2/modeling_esm_te.py b/bionemo-recipes/models/esm2/modeling_esm_te.py index a9d03fad5e..298fc70f6b 100644 --- a/bionemo-recipes/models/esm2/modeling_esm_te.py +++ b/bionemo-recipes/models/esm2/modeling_esm_te.py @@ -73,6 +73,7 @@ def __init__( max_seq_length: Optional[int] = None, padded_vocab_size: Optional[int] = 64, attn_mask_type: str = "padding", + add_pooling_layer: bool = False, layer_precision: list[str | None] | None = None, **kwargs, ): @@ -103,6 +104,9 @@ def __init__( padded_vocab_size: The padded vocabulary size to support FP8. If not provided, defaults to vocab_size. Must be greater than or equal to vocab_size. attn_mask_type: The type of attention mask to use. + add_pooling_layer: Whether the base model should include a pooling layer. + Defaults to ``False`` because exported checkpoints do not contain pooler + weights. Set to ``True`` only if you have a checkpoint with pooler weights. layer_precision: Per-layer quantization precision, a list of length ``num_hidden_layers`` where each element is ``"fp8"``, ``"fp4"``, or ``None`` (BF16 fallback). ``None`` (the default) means no quantization is configured. 
@@ -117,6 +121,7 @@ def __init__( self.micro_batch_size = micro_batch_size self.max_seq_length = max_seq_length self.attn_mask_type = attn_mask_type + self.add_pooling_layer = add_pooling_layer self.layer_precision = layer_precision # Set padded_vocab_size with default fallback to vocab_size @@ -289,7 +294,7 @@ class NVEsmPreTrainedModel(EsmPreTrainedModel): """An abstract class to handle weights initialization and pretrained model loading.""" config_class = NVEsmConfig - base_model_prefix = "esm" + base_model_prefix = "model" supports_gradient_checkpointing = False accepts_loss_kwargs = False _no_split_modules = ( @@ -305,11 +310,11 @@ def init_empty_weights(self): if hasattr(module, "reset_parameters"): module.reset_parameters() - # The esm.embeddings layer is the only non-TE layer in this model we need to deal with. We use + # The embeddings layer is the only non-TE layer in this model we need to deal with. We use # `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard - # deviation. - self.esm.embeddings.word_embeddings.to_empty(device="cuda") - self.esm.embeddings.apply(self._init_weights) + # deviation. self.base_model resolves to self.model for wrapper classes or self for NVEsmModel. + self.base_model.embeddings.word_embeddings.to_empty(device="cuda") + self.base_model.embeddings.apply(self._init_weights) # Meta-device init seems to break weight tying, so we re-tie the weights here. self.tie_weights() @@ -334,14 +339,16 @@ def _init_weights(self, module): super()._init_weights(module) def state_dict(self, *args, **kwargs): - """Override state_dict to filter out TransformerEngine's _extra_state keys. + """Override state_dict to filter out non-loadable keys. - TransformerEngine layers add _extra_state attributes that are not compatible with HuggingFace v5 model loading. - These are filtered out to ensure checkpoints can be loaded with from_pretrained(). 
+ Filters out: + - ``_extra_state`` keys: TransformerEngine-specific, not loadable by HuggingFace v5. + - ``.inv_freq`` buffers: Computed at init time by RotaryPositionEmbedding, not needed + in the checkpoint and not loadable by vLLM's AutoWeightsLoader (which only iterates + over ``named_parameters``, not ``named_buffers``). """ state_dict = super().state_dict(*args, **kwargs) - # Filter out _extra_state keys which are TransformerEngine-specific and not loadable - return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")} + return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state") and not k.endswith(".inv_freq")} class NVEsmModel(NVEsmPreTrainedModel): @@ -350,16 +357,20 @@ class NVEsmModel(NVEsmPreTrainedModel): This model uses NVDIA's TransformerEngine to optimize attention layer training and inference. """ - def __init__(self, config: NVEsmConfig, add_pooling_layer: bool = True): + def __init__(self, config: NVEsmConfig, add_pooling_layer: Optional[bool] = None): """Initialize a NVEsmModel. Args: config (NVEsmConfig): The configuration of the model. - add_pooling_layer (bool): Whether to add a pooling layer. + add_pooling_layer (bool): Whether to add a pooling layer. If ``None``, + reads ``config.add_pooling_layer`` (defaults to ``True``). 
""" super().__init__(config) self.config = config + if add_pooling_layer is None: + add_pooling_layer = getattr(config, "add_pooling_layer", True) + # Ensure pad_token_id is set properly, defaulting to 0 if not specified if not hasattr(config, "pad_token_id") or config.pad_token_id is None: config.pad_token_id = 0 @@ -449,7 +460,9 @@ def forward( class NVEsmForMaskedLM(NVEsmPreTrainedModel): """NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling.""" - _tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.decoder.weight": "esm.embeddings.word_embeddings.weight"} + _tied_weights_keys: ClassVar[dict[str, str]] = { + "lm_head.decoder.weight": "model.embeddings.word_embeddings.weight" + } _do_not_quantize = ("lm_head.dense", "lm_head.decoder") # Flag for testing that these layers are not quantized. def __init__(self, config: NVEsmConfig): @@ -466,7 +479,7 @@ def __init__(self, config: NVEsmConfig): "bi-directional self-attention." ) - self.esm = NVEsmModel(config, add_pooling_layer=False) + self.model = NVEsmModel(config, add_pooling_layer=False) self.lm_head = NVEsmLMHead(config) self.post_init() @@ -501,7 +514,7 @@ def forward( Returns: MaskedLMOutput: The output of the model. """ - outputs = self.esm( + outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -719,7 +732,7 @@ def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels - self.esm = NVEsmModel(config, add_pooling_layer=False) + self.model = NVEsmModel(config, add_pooling_layer=False) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = transformer_engine.pytorch.Linear( config.hidden_size, @@ -745,7 +758,7 @@ def forward( labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
""" - outputs = self.esm( + outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, diff --git a/bionemo-recipes/models/esm2/tests/test_cp_bshd.py b/bionemo-recipes/models/esm2/tests/test_cp_bshd.py index 2af776b880..5e9d7f96b7 100644 --- a/bionemo-recipes/models/esm2/tests/test_cp_bshd.py +++ b/bionemo-recipes/models/esm2/tests/test_cp_bshd.py @@ -209,8 +209,8 @@ def test_context_parallel_equivalence_2process(): # Sample gradients from a few layers for comparison sample_layers = [ - model.esm.encoder.layers[0].self_attention.core_attention, - model.esm.encoder.layers[0].self_attention.layernorm_qkv, + model.model.encoder.layers[0].self_attention.core_attention, + model.model.encoder.layers[0].self_attention.layernorm_qkv, ] # Now grab the gradients from the sample layers @@ -262,7 +262,7 @@ def test_context_parallel_equivalence_2process(): cp_world_size = torch.distributed.get_world_size(group=cp_group) # Set up context parallelism for each layer - for i, transformer_layer in enumerate(model.module.esm.encoder.layers): + for i, transformer_layer in enumerate(model.module.model.encoder.layers): transformer_layer.set_context_parallel_group( cp_group, torch.distributed.get_process_group_ranks(device_mesh["cp"].get_group()), torch.cuda.Stream() ) @@ -347,8 +347,8 @@ def test_context_parallel_equivalence_2process(): # Capture gradients from the same layers in the CP model # Note: DDP wraps the model with 'module.' 
prefix sample_layers_cp = [ - model.module.esm.encoder.layers[0].self_attention.core_attention, - model.module.esm.encoder.layers[0].self_attention.layernorm_qkv, + model.module.model.encoder.layers[0].self_attention.core_attention, + model.module.model.encoder.layers[0].self_attention.layernorm_qkv, ] gradients_cp = {} diff --git a/bionemo-recipes/models/esm2/tests/test_cp_thd.py b/bionemo-recipes/models/esm2/tests/test_cp_thd.py index c17618ba98..98dd62b742 100644 --- a/bionemo-recipes/models/esm2/tests/test_cp_thd.py +++ b/bionemo-recipes/models/esm2/tests/test_cp_thd.py @@ -200,8 +200,8 @@ def test_context_parallel_equivalence_2process(): # Sample gradients from a few layers for comparison sample_layers = [ - model.esm.encoder.layers[0].self_attention.core_attention, - model.esm.encoder.layers[0].self_attention.layernorm_qkv, + model.model.encoder.layers[0].self_attention.core_attention, + model.model.encoder.layers[0].self_attention.layernorm_qkv, ] # Now grab the gradients from the sample layers @@ -253,7 +253,7 @@ def test_context_parallel_equivalence_2process(): cp_world_size = torch.distributed.get_world_size(group=cp_group) # Set up context parallelism for each layer - for i, transformer_layer in enumerate(model.module.esm.encoder.layers): + for i, transformer_layer in enumerate(model.module.model.encoder.layers): transformer_layer.set_context_parallel_group( cp_group, torch.distributed.get_process_group_ranks(device_mesh["cp"].get_group()), torch.cuda.Stream() ) @@ -344,8 +344,8 @@ def test_context_parallel_equivalence_2process(): # Capture gradients from the same layers in the CP model # Note: DDP wraps the model with 'module.' 
prefix sample_layers_cp = [ - model.module.esm.encoder.layers[0].self_attention.core_attention, - model.module.esm.encoder.layers[0].self_attention.layernorm_qkv, + model.module.model.encoder.layers[0].self_attention.core_attention, + model.module.model.encoder.layers[0].self_attention.layernorm_qkv, ] gradients_cp = {} diff --git a/bionemo-recipes/models/esm2/tests/test_distributed_fp8.py b/bionemo-recipes/models/esm2/tests/test_distributed_fp8.py index 04246a0d81..aae4e33ae0 100644 --- a/bionemo-recipes/models/esm2/tests/test_distributed_fp8.py +++ b/bionemo-recipes/models/esm2/tests/test_distributed_fp8.py @@ -168,7 +168,7 @@ def is_main_process(self) -> bool: model = NVEsmForMaskedLM(config) if args.strategy is Strategy.FSDP2: - for layer in model.esm.encoder.layers: + for layer in model.model.encoder.layers: fully_shard(layer, mesh=device_mesh["dp"]) fully_shard(model, mesh=device_mesh["dp"]) model.to(device) @@ -199,7 +199,7 @@ def is_main_process(self) -> bool: ) # Attach FP8 recipes to the encoder (layer precision is already on config). 
- encoder = model.module.esm.encoder if args.strategy is Strategy.DDP else model.esm.encoder + encoder = model.module.model.encoder if args.strategy is Strategy.DDP else model.model.encoder encoder.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=None) model.train() diff --git a/bionemo-recipes/models/esm2/tests/test_distributed_strategies.py b/bionemo-recipes/models/esm2/tests/test_distributed_strategies.py index 8347f0f1ef..ab96c6590e 100644 --- a/bionemo-recipes/models/esm2/tests/test_distributed_strategies.py +++ b/bionemo-recipes/models/esm2/tests/test_distributed_strategies.py @@ -193,7 +193,7 @@ def run_forward_backward(use_te: bool, strategy: Strategy, input_data: dict, dis revision="c731040f", ) model = NVEsmForMaskedLM(config) - transformer_layers = model.esm.encoder.layers + transformer_layers = model.model.encoder.layers else: model = AutoModelForMaskedLM.from_pretrained( "facebook/esm2_t6_8M_UR50D", diff --git a/bionemo-recipes/models/esm2/tests/test_modeling_esm_te.py b/bionemo-recipes/models/esm2/tests/test_modeling_esm_te.py index d2a8fe0e1e..e9ca8c5a2f 100644 --- a/bionemo-recipes/models/esm2/tests/test_modeling_esm_te.py +++ b/bionemo-recipes/models/esm2/tests/test_modeling_esm_te.py @@ -81,7 +81,7 @@ def get_upstream_model_class(self) -> Type[PreTrainedModel]: def get_layer_path(self, model: PreTrainedModel) -> List[nn.Module]: """Return the list of transformer layers.""" - return list(model.esm.encoder.layers) # type: ignore + return list(model.model.encoder.layers) # type: ignore def get_reference_model_no_weights(self) -> PreTrainedModel: """For checkpoint conversion tests to pass, we need to remove the unused contact head.""" @@ -195,7 +195,7 @@ def test_convert_state_dict_explicit_check(self): # Check packed QKV weights for i in range(model_hf.config.num_hidden_layers): - k = f"esm.encoder.layers.{i}.self_attention.layernorm_qkv.weight" + k = f"model.encoder.layers.{i}.self_attention.layernorm_qkv.weight" v = [ 
f"esm.encoder.layer.{i}.attention.self.query.weight", f"esm.encoder.layer.{i}.attention.self.key.weight", @@ -217,7 +217,7 @@ def test_convert_state_dict_explicit_check(self): # Check packed QKV biases for i in range(model_hf.config.num_hidden_layers): - k = f"esm.encoder.layers.{i}.self_attention.layernorm_qkv.bias" + k = f"model.encoder.layers.{i}.self_attention.layernorm_qkv.bias" v = [ f"esm.encoder.layer.{i}.attention.self.query.bias", f"esm.encoder.layer.{i}.attention.self.key.bias", @@ -243,7 +243,7 @@ def test_convert_state_dict_explicit_check(self): torch.testing.assert_close( _pad_weights(ctx_mock, model_hf.state_dict()["esm.embeddings.word_embeddings.weight"]), - model_te.state_dict()["esm.embeddings.word_embeddings.weight"], + model_te.state_dict()["model.embeddings.word_embeddings.weight"], ) torch.testing.assert_close( _pad_weights(ctx_mock, model_hf.state_dict()["lm_head.decoder.weight"]), @@ -254,7 +254,7 @@ def test_convert_state_dict_explicit_check(self): model_te.state_dict()["lm_head.decoder.bias"], ) - te_state_dict_keys.remove("esm.embeddings.word_embeddings.weight") + te_state_dict_keys.remove("model.embeddings.word_embeddings.weight") te_state_dict_keys.remove("lm_head.decoder.weight") te_state_dict_keys.remove("lm_head.decoder.bias") @@ -267,7 +267,7 @@ def test_convert_state_dict_explicit_check(self): ) assert ( - model_te.state_dict()["esm.embeddings.word_embeddings.weight"].data_ptr() + model_te.state_dict()["model.embeddings.word_embeddings.weight"].data_ptr() == model_te.state_dict()["lm_head.decoder.weight"].data_ptr() ) diff --git a/bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py b/bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py index a9d03fad5e..298fc70f6b 100644 --- a/bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py +++ b/bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py @@ -73,6 +73,7 @@ def __init__( max_seq_length: 
Optional[int] = None, padded_vocab_size: Optional[int] = 64, attn_mask_type: str = "padding", + add_pooling_layer: bool = False, layer_precision: list[str | None] | None = None, **kwargs, ): @@ -103,6 +104,9 @@ def __init__( padded_vocab_size: The padded vocabulary size to support FP8. If not provided, defaults to vocab_size. Must be greater than or equal to vocab_size. attn_mask_type: The type of attention mask to use. + add_pooling_layer: Whether the base model should include a pooling layer. + Defaults to ``False`` because exported checkpoints do not contain pooler + weights. Set to ``True`` only if you have a checkpoint with pooler weights. layer_precision: Per-layer quantization precision, a list of length ``num_hidden_layers`` where each element is ``"fp8"``, ``"fp4"``, or ``None`` (BF16 fallback). ``None`` (the default) means no quantization is configured. @@ -117,6 +121,7 @@ def __init__( self.micro_batch_size = micro_batch_size self.max_seq_length = max_seq_length self.attn_mask_type = attn_mask_type + self.add_pooling_layer = add_pooling_layer self.layer_precision = layer_precision # Set padded_vocab_size with default fallback to vocab_size @@ -289,7 +294,7 @@ class NVEsmPreTrainedModel(EsmPreTrainedModel): """An abstract class to handle weights initialization and pretrained model loading.""" config_class = NVEsmConfig - base_model_prefix = "esm" + base_model_prefix = "model" supports_gradient_checkpointing = False accepts_loss_kwargs = False _no_split_modules = ( @@ -305,11 +310,11 @@ def init_empty_weights(self): if hasattr(module, "reset_parameters"): module.reset_parameters() - # The esm.embeddings layer is the only non-TE layer in this model we need to deal with. We use + # The embeddings layer is the only non-TE layer in this model we need to deal with. We use # `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard - # deviation. 
- self.esm.embeddings.word_embeddings.to_empty(device="cuda") - self.esm.embeddings.apply(self._init_weights) + # deviation. self.base_model resolves to self.model for wrapper classes or self for NVEsmModel. + self.base_model.embeddings.word_embeddings.to_empty(device="cuda") + self.base_model.embeddings.apply(self._init_weights) # Meta-device init seems to break weight tying, so we re-tie the weights here. self.tie_weights() @@ -334,14 +339,16 @@ def _init_weights(self, module): super()._init_weights(module) def state_dict(self, *args, **kwargs): - """Override state_dict to filter out TransformerEngine's _extra_state keys. + """Override state_dict to filter out non-loadable keys. - TransformerEngine layers add _extra_state attributes that are not compatible with HuggingFace v5 model loading. - These are filtered out to ensure checkpoints can be loaded with from_pretrained(). + Filters out: + - ``_extra_state`` keys: TransformerEngine-specific, not loadable by HuggingFace v5. + - ``.inv_freq`` buffers: Computed at init time by RotaryPositionEmbedding, not needed + in the checkpoint and not loadable by vLLM's AutoWeightsLoader (which only iterates + over ``named_parameters``, not ``named_buffers``). """ state_dict = super().state_dict(*args, **kwargs) - # Filter out _extra_state keys which are TransformerEngine-specific and not loadable - return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")} + return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state") and not k.endswith(".inv_freq")} class NVEsmModel(NVEsmPreTrainedModel): @@ -350,16 +357,20 @@ class NVEsmModel(NVEsmPreTrainedModel): This model uses NVDIA's TransformerEngine to optimize attention layer training and inference. """ - def __init__(self, config: NVEsmConfig, add_pooling_layer: bool = True): + def __init__(self, config: NVEsmConfig, add_pooling_layer: Optional[bool] = None): """Initialize a NVEsmModel. Args: config (NVEsmConfig): The configuration of the model. 
- add_pooling_layer (bool): Whether to add a pooling layer. + add_pooling_layer (bool): Whether to add a pooling layer. If ``None``, + reads ``config.add_pooling_layer`` (defaults to ``True``). """ super().__init__(config) self.config = config + if add_pooling_layer is None: + add_pooling_layer = getattr(config, "add_pooling_layer", True) + # Ensure pad_token_id is set properly, defaulting to 0 if not specified if not hasattr(config, "pad_token_id") or config.pad_token_id is None: config.pad_token_id = 0 @@ -449,7 +460,9 @@ def forward( class NVEsmForMaskedLM(NVEsmPreTrainedModel): """NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling.""" - _tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.decoder.weight": "esm.embeddings.word_embeddings.weight"} + _tied_weights_keys: ClassVar[dict[str, str]] = { + "lm_head.decoder.weight": "model.embeddings.word_embeddings.weight" + } _do_not_quantize = ("lm_head.dense", "lm_head.decoder") # Flag for testing that these layers are not quantized. def __init__(self, config: NVEsmConfig): @@ -466,7 +479,7 @@ def __init__(self, config: NVEsmConfig): "bi-directional self-attention." ) - self.esm = NVEsmModel(config, add_pooling_layer=False) + self.model = NVEsmModel(config, add_pooling_layer=False) self.lm_head = NVEsmLMHead(config) self.post_init() @@ -501,7 +514,7 @@ def forward( Returns: MaskedLMOutput: The output of the model. 
""" - outputs = self.esm( + outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -719,7 +732,7 @@ def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels - self.esm = NVEsmModel(config, add_pooling_layer=False) + self.model = NVEsmModel(config, add_pooling_layer=False) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = transformer_engine.pytorch.Linear( config.hidden_size, @@ -745,7 +758,7 @@ def forward( labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. """ - outputs = self.esm( + outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, diff --git a/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py b/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py index a9d03fad5e..298fc70f6b 100644 --- a/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py +++ b/bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py @@ -73,6 +73,7 @@ def __init__( max_seq_length: Optional[int] = None, padded_vocab_size: Optional[int] = 64, attn_mask_type: str = "padding", + add_pooling_layer: bool = False, layer_precision: list[str | None] | None = None, **kwargs, ): @@ -103,6 +104,9 @@ def __init__( padded_vocab_size: The padded vocabulary size to support FP8. If not provided, defaults to vocab_size. Must be greater than or equal to vocab_size. attn_mask_type: The type of attention mask to use. + add_pooling_layer: Whether the base model should include a pooling layer. + Defaults to ``False`` because exported checkpoints do not contain pooler + weights. Set to ``True`` only if you have a checkpoint with pooler weights. 
layer_precision: Per-layer quantization precision, a list of length ``num_hidden_layers`` where each element is ``"fp8"``, ``"fp4"``, or ``None`` (BF16 fallback). ``None`` (the default) means no quantization is configured. @@ -117,6 +121,7 @@ def __init__( self.micro_batch_size = micro_batch_size self.max_seq_length = max_seq_length self.attn_mask_type = attn_mask_type + self.add_pooling_layer = add_pooling_layer self.layer_precision = layer_precision # Set padded_vocab_size with default fallback to vocab_size @@ -289,7 +294,7 @@ class NVEsmPreTrainedModel(EsmPreTrainedModel): """An abstract class to handle weights initialization and pretrained model loading.""" config_class = NVEsmConfig - base_model_prefix = "esm" + base_model_prefix = "model" supports_gradient_checkpointing = False accepts_loss_kwargs = False _no_split_modules = ( @@ -305,11 +310,11 @@ def init_empty_weights(self): if hasattr(module, "reset_parameters"): module.reset_parameters() - # The esm.embeddings layer is the only non-TE layer in this model we need to deal with. We use + # The embeddings layer is the only non-TE layer in this model we need to deal with. We use # `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard - # deviation. - self.esm.embeddings.word_embeddings.to_empty(device="cuda") - self.esm.embeddings.apply(self._init_weights) + # deviation. self.base_model resolves to self.model for wrapper classes or self for NVEsmModel. + self.base_model.embeddings.word_embeddings.to_empty(device="cuda") + self.base_model.embeddings.apply(self._init_weights) # Meta-device init seems to break weight tying, so we re-tie the weights here. self.tie_weights() @@ -334,14 +339,16 @@ def _init_weights(self, module): super()._init_weights(module) def state_dict(self, *args, **kwargs): - """Override state_dict to filter out TransformerEngine's _extra_state keys. + """Override state_dict to filter out non-loadable keys. 
- TransformerEngine layers add _extra_state attributes that are not compatible with HuggingFace v5 model loading. - These are filtered out to ensure checkpoints can be loaded with from_pretrained(). + Filters out: + - ``_extra_state`` keys: TransformerEngine-specific, not loadable by HuggingFace v5. + - ``.inv_freq`` buffers: Computed at init time by RotaryPositionEmbedding, not needed + in the checkpoint and not loadable by vLLM's AutoWeightsLoader (which only iterates + over ``named_parameters``, not ``named_buffers``). """ state_dict = super().state_dict(*args, **kwargs) - # Filter out _extra_state keys which are TransformerEngine-specific and not loadable - return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")} + return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state") and not k.endswith(".inv_freq")} class NVEsmModel(NVEsmPreTrainedModel): @@ -350,16 +357,20 @@ class NVEsmModel(NVEsmPreTrainedModel): This model uses NVDIA's TransformerEngine to optimize attention layer training and inference. """ - def __init__(self, config: NVEsmConfig, add_pooling_layer: bool = True): + def __init__(self, config: NVEsmConfig, add_pooling_layer: Optional[bool] = None): """Initialize a NVEsmModel. Args: config (NVEsmConfig): The configuration of the model. - add_pooling_layer (bool): Whether to add a pooling layer. + add_pooling_layer (bool): Whether to add a pooling layer. If ``None``, + reads ``config.add_pooling_layer`` (defaults to ``True``). 
""" super().__init__(config) self.config = config + if add_pooling_layer is None: + add_pooling_layer = getattr(config, "add_pooling_layer", True) + # Ensure pad_token_id is set properly, defaulting to 0 if not specified if not hasattr(config, "pad_token_id") or config.pad_token_id is None: config.pad_token_id = 0 @@ -449,7 +460,9 @@ def forward( class NVEsmForMaskedLM(NVEsmPreTrainedModel): """NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling.""" - _tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.decoder.weight": "esm.embeddings.word_embeddings.weight"} + _tied_weights_keys: ClassVar[dict[str, str]] = { + "lm_head.decoder.weight": "model.embeddings.word_embeddings.weight" + } _do_not_quantize = ("lm_head.dense", "lm_head.decoder") # Flag for testing that these layers are not quantized. def __init__(self, config: NVEsmConfig): @@ -466,7 +479,7 @@ def __init__(self, config: NVEsmConfig): "bi-directional self-attention." ) - self.esm = NVEsmModel(config, add_pooling_layer=False) + self.model = NVEsmModel(config, add_pooling_layer=False) self.lm_head = NVEsmLMHead(config) self.post_init() @@ -501,7 +514,7 @@ def forward( Returns: MaskedLMOutput: The output of the model. """ - outputs = self.esm( + outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -719,7 +732,7 @@ def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels - self.esm = NVEsmModel(config, add_pooling_layer=False) + self.model = NVEsmModel(config, add_pooling_layer=False) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = transformer_engine.pytorch.Linear( config.hidden_size, @@ -745,7 +758,7 @@ def forward( labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
""" - outputs = self.esm( + outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, diff --git a/bionemo-recipes/recipes/esm2_native_te/tests/test_stop_and_go.py b/bionemo-recipes/recipes/esm2_native_te/tests/test_stop_and_go.py index a8b8afc6af..bbb9fa01d9 100644 --- a/bionemo-recipes/recipes/esm2_native_te/tests/test_stop_and_go.py +++ b/bionemo-recipes/recipes/esm2_native_te/tests/test_stop_and_go.py @@ -74,8 +74,9 @@ def test_stop_and_go_checkpointing_and_dataloader_restoration_single_gpu(tmp_pat # The huggingface model has a contact head that we don't use in masked language pre-training, so we delete it to # avoid errors with unused parameters. + base = model.model if hasattr(model, "model") else model.esm try: - del model.esm.contact_head + del base.contact_head except AttributeError: pass @@ -156,8 +157,9 @@ def test_stop_and_go_checkpointing_and_dataloader_restoration_single_gpu(tmp_pat config = AutoConfig.from_pretrained("example_8m_checkpoint", trust_remote_code=True, dtype=torch.bfloat16) resumed_model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True) + resumed_base = resumed_model.model if hasattr(resumed_model, "model") else resumed_model.esm try: - del resumed_model.esm.contact_head + del resumed_base.contact_head except AttributeError: pass diff --git a/bionemo-recipes/recipes/esm2_native_te/train_ddp.py b/bionemo-recipes/recipes/esm2_native_te/train_ddp.py index b2160a31e0..ee6a026c27 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_ddp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_ddp.py @@ -113,8 +113,9 @@ def main(args: DictConfig) -> float | None: # The huggingface model has a contact head that we don't use in masked language pre-training, so we delete it to # avoid errors with unused parameters. 
+ base = model.model if hasattr(model, "model") else model.esm try: - del model.esm.contact_head + del base.contact_head except AttributeError: pass diff --git a/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py b/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py index e91c34952e..b179a53eb1 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py @@ -134,8 +134,9 @@ def main(args: DictConfig) -> float | None: process_group=group_fsdp_cp, ) + base = model.module.model if hasattr(model.module, "model") else model.module.esm if args.cp_size > 1: - for i, transformer_layer in enumerate(model.module.esm.encoder.layers): + for i, transformer_layer in enumerate(base.encoder.layers): logger.debug(f"Rank {dist_config.rank}: Setting CP group for layer {i}") transformer_layer.set_context_parallel_group( device_mesh["cp"].get_group(), diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py index 6dd6ccaba8..e8af6268de 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py @@ -119,7 +119,8 @@ def main(args: DictConfig) -> float | None: logger.info("Initialized Model:\n%s", model) # We call the transformer stack "layers" in our TE models, but it's called "layer" in the original ESM-2 models. 
- transformer_stack = model.esm.encoder.layers if hasattr(model.esm.encoder, "layers") else model.esm.encoder.layer + base = model.model if hasattr(model, "model") else model.esm + transformer_stack = base.encoder.layers if hasattr(base.encoder, "layers") else base.encoder.layer if args.use_fp32_master_weights: mp_policy = MixedPrecisionPolicy( diff --git a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py index 6c00b54d0a..872c8c9c09 100644 --- a/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py +++ b/bionemo-recipes/recipes/esm2_native_te/train_fsdp2_cp.py @@ -129,8 +129,10 @@ def main(args: DictConfig) -> float | None: logger.info("Initialized Model:\n%s", model) + # TE models use `model.model`, facebook HF models use `model.esm`. + base = model.model if hasattr(model, "model") else model.esm # We call the transformer stack "layers" in our TE models, but it's called "layer" in the original ESM-2 models. - transformer_stack = model.esm.encoder.layers if hasattr(model.esm.encoder, "layers") else model.esm.encoder.layer + transformer_stack = base.encoder.layers if hasattr(base.encoder, "layers") else base.encoder.layer # Fully shard takes in a DeviceMesh object, which is a 2D mesh of dimensions (CP_dimension, DP_dimension). # FSDP2 will shard the model across the DP (dim=1) dimension and then duplicate across the CP (dim=0) dimension. 
for layer in transformer_stack: diff --git a/bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py b/bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py index a9d03fad5e..298fc70f6b 100644 --- a/bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py +++ b/bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py @@ -73,6 +73,7 @@ def __init__( max_seq_length: Optional[int] = None, padded_vocab_size: Optional[int] = 64, attn_mask_type: str = "padding", + add_pooling_layer: bool = False, layer_precision: list[str | None] | None = None, **kwargs, ): @@ -103,6 +104,9 @@ def __init__( padded_vocab_size: The padded vocabulary size to support FP8. If not provided, defaults to vocab_size. Must be greater than or equal to vocab_size. attn_mask_type: The type of attention mask to use. + add_pooling_layer: Whether the base model should include a pooling layer. + Defaults to ``False`` because exported checkpoints do not contain pooler + weights. Set to ``True`` only if you have a checkpoint with pooler weights. layer_precision: Per-layer quantization precision, a list of length ``num_hidden_layers`` where each element is ``"fp8"``, ``"fp4"``, or ``None`` (BF16 fallback). ``None`` (the default) means no quantization is configured. 
@@ -117,6 +121,7 @@ def __init__( self.micro_batch_size = micro_batch_size self.max_seq_length = max_seq_length self.attn_mask_type = attn_mask_type + self.add_pooling_layer = add_pooling_layer self.layer_precision = layer_precision # Set padded_vocab_size with default fallback to vocab_size @@ -289,7 +294,7 @@ class NVEsmPreTrainedModel(EsmPreTrainedModel): """An abstract class to handle weights initialization and pretrained model loading.""" config_class = NVEsmConfig - base_model_prefix = "esm" + base_model_prefix = "model" supports_gradient_checkpointing = False accepts_loss_kwargs = False _no_split_modules = ( @@ -305,11 +310,11 @@ def init_empty_weights(self): if hasattr(module, "reset_parameters"): module.reset_parameters() - # The esm.embeddings layer is the only non-TE layer in this model we need to deal with. We use + # The embeddings layer is the only non-TE layer in this model we need to deal with. We use # `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard - # deviation. - self.esm.embeddings.word_embeddings.to_empty(device="cuda") - self.esm.embeddings.apply(self._init_weights) + # deviation. self.base_model resolves to self.model for wrapper classes or self for NVEsmModel. + self.base_model.embeddings.word_embeddings.to_empty(device="cuda") + self.base_model.embeddings.apply(self._init_weights) # Meta-device init seems to break weight tying, so we re-tie the weights here. self.tie_weights() @@ -334,14 +339,16 @@ def _init_weights(self, module): super()._init_weights(module) def state_dict(self, *args, **kwargs): - """Override state_dict to filter out TransformerEngine's _extra_state keys. + """Override state_dict to filter out non-loadable keys. - TransformerEngine layers add _extra_state attributes that are not compatible with HuggingFace v5 model loading. - These are filtered out to ensure checkpoints can be loaded with from_pretrained(). 
+ Filters out: + - ``_extra_state`` keys: TransformerEngine-specific, not loadable by HuggingFace v5. + - ``.inv_freq`` buffers: Computed at init time by RotaryPositionEmbedding, not needed + in the checkpoint and not loadable by vLLM's AutoWeightsLoader (which only iterates + over ``named_parameters``, not ``named_buffers``). """ state_dict = super().state_dict(*args, **kwargs) - # Filter out _extra_state keys which are TransformerEngine-specific and not loadable - return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")} + return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state") and not k.endswith(".inv_freq")} class NVEsmModel(NVEsmPreTrainedModel): @@ -350,16 +357,20 @@ class NVEsmModel(NVEsmPreTrainedModel): This model uses NVDIA's TransformerEngine to optimize attention layer training and inference. """ - def __init__(self, config: NVEsmConfig, add_pooling_layer: bool = True): + def __init__(self, config: NVEsmConfig, add_pooling_layer: Optional[bool] = None): """Initialize a NVEsmModel. Args: config (NVEsmConfig): The configuration of the model. - add_pooling_layer (bool): Whether to add a pooling layer. + add_pooling_layer (bool): Whether to add a pooling layer. If ``None``, + reads ``config.add_pooling_layer`` (defaults to ``True``). 
""" super().__init__(config) self.config = config + if add_pooling_layer is None: + add_pooling_layer = getattr(config, "add_pooling_layer", True) + # Ensure pad_token_id is set properly, defaulting to 0 if not specified if not hasattr(config, "pad_token_id") or config.pad_token_id is None: config.pad_token_id = 0 @@ -449,7 +460,9 @@ def forward( class NVEsmForMaskedLM(NVEsmPreTrainedModel): """NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling.""" - _tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.decoder.weight": "esm.embeddings.word_embeddings.weight"} + _tied_weights_keys: ClassVar[dict[str, str]] = { + "lm_head.decoder.weight": "model.embeddings.word_embeddings.weight" + } _do_not_quantize = ("lm_head.dense", "lm_head.decoder") # Flag for testing that these layers are not quantized. def __init__(self, config: NVEsmConfig): @@ -466,7 +479,7 @@ def __init__(self, config: NVEsmConfig): "bi-directional self-attention." ) - self.esm = NVEsmModel(config, add_pooling_layer=False) + self.model = NVEsmModel(config, add_pooling_layer=False) self.lm_head = NVEsmLMHead(config) self.post_init() @@ -501,7 +514,7 @@ def forward( Returns: MaskedLMOutput: The output of the model. """ - outputs = self.esm( + outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -719,7 +732,7 @@ def __init__(self, config): super().__init__(config) self.num_labels = config.num_labels - self.esm = NVEsmModel(config, add_pooling_layer=False) + self.model = NVEsmModel(config, add_pooling_layer=False) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = transformer_engine.pytorch.Linear( config.hidden_size, @@ -745,7 +758,7 @@ def forward( labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. 
""" - outputs = self.esm( + outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, diff --git a/bionemo-recipes/recipes/vllm_inference/.ci_build.sh b/bionemo-recipes/recipes/vllm_inference/.ci_build.sh new file mode 100755 index 0000000000..0d53640a3d --- /dev/null +++ b/bionemo-recipes/recipes/vllm_inference/.ci_build.sh @@ -0,0 +1,4 @@ +#!/bin/bash -x +cd esm2 +PIP_CONSTRAINT= pip install -r requirements.txt +./install_vllm.sh diff --git a/bionemo-recipes/recipes/vllm_inference/.ci_test_env.sh b/bionemo-recipes/recipes/vllm_inference/.ci_test_env.sh new file mode 100644 index 0000000000..f667302446 --- /dev/null +++ b/bionemo-recipes/recipes/vllm_inference/.ci_test_env.sh @@ -0,0 +1,2 @@ + +cd esm2 diff --git a/bionemo-recipes/recipes/vllm_inference/esm2/Dockerfile b/bionemo-recipes/recipes/vllm_inference/esm2/Dockerfile new file mode 100644 index 0000000000..d5697592f4 --- /dev/null +++ b/bionemo-recipes/recipes/vllm_inference/esm2/Dockerfile @@ -0,0 +1,24 @@ +FROM nvcr.io/nvidia/pytorch:26.02-py3 +WORKDIR /workspace/bionemo +COPY . . +RUN --mount=type=cache,target=/root/.cache/pip \ + PIP_CONSTRAINT= pip install -r requirements.txt + +WORKDIR /workspace +ARG INSTALL_VLLM=false +ARG TORCH_CUDA_ARCH_LIST="" +ARG MAX_JOBS=8 +ARG UV_BREAK_SYSTEM_PACKAGES=1 +RUN if [ "$INSTALL_VLLM" = "true" ]; then \ + if [ -z "$TORCH_CUDA_ARCH_LIST" ]; then \ + echo "ERROR: TORCH_CUDA_ARCH_LIST must be set when INSTALL_VLLM=true" && exit 1; \ + fi && \ + git clone --branch v0.15.1 --depth 1 https://github.com/vllm-project/vllm.git && \ + cd vllm && \ + python use_existing_torch.py && \ + uv pip install -r requirements/build.txt --system && \ + uv pip install --no-build-isolation -e . 
--system && \ + pip install --upgrade "transformers[torch]"; \ + fi + +WORKDIR /workspace/bionemo diff --git a/bionemo-recipes/recipes/vllm_inference/esm2/README.md b/bionemo-recipes/recipes/vllm_inference/esm2/README.md new file mode 100644 index 0000000000..9d46b972eb --- /dev/null +++ b/bionemo-recipes/recipes/vllm_inference/esm2/README.md @@ -0,0 +1,55 @@ +# ESM-2 vLLM Inference + +This recipe demonstrates running inference on +[ESM-2 TE checkpoints](../../../models/esm2/) using +[vLLM](https://github.com/vllm-project/vllm) (>= 0.14) as a pooling/embedding model. + +The exported TE checkpoints on HuggingFace Hub are directly compatible with vLLM. +No conversion scripts or weight renaming are needed: + +```python +from vllm import LLM + +model = LLM( + model="nvidia/esm2_t6_8M_UR50D", + runner="pooling", + trust_remote_code=True, + enforce_eager=True, + max_num_batched_tokens=1026, +) + +prompts = ["MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLK"] +outputs = model.embed(prompts) +print(outputs[0].outputs.embedding[:5]) +``` + +See [tests/test_vllm.py](tests/test_vllm.py) for a full golden-value validation across +vLLM, native HuggingFace, and the nvidia Hub reference model. + +## Installing vLLM in the container + +There are two ways to get vLLM installed in the Docker image. + +**Option 1: Build-time installation via Dockerfile build arg** + +Pass `--build-arg INSTALL_VLLM=true` and `--build-arg TORCH_CUDA_ARCH_LIST=` when +building the image. `TORCH_CUDA_ARCH_LIST` is required when `INSTALL_VLLM=true` (the +Dockerfile will error if it is not set): + +```bash +docker build -t esm2-vllm \ + --build-arg INSTALL_VLLM=true \ + --build-arg TORCH_CUDA_ARCH_LIST="9.0" . +``` + +**Option 2: Post-build installation via `install_vllm.sh`** + +Build the base image normally, then run `install_vllm.sh` inside the container. The script +auto-detects the GPU architecture, or you can pass an explicit arch argument: + +```bash +docker build -t esm2 . 
+docker run --rm -it --gpus all esm2 bash -c "./install_vllm.sh" +# or with an explicit architecture: +docker run --rm -it --gpus all esm2 bash -c "./install_vllm.sh 9.0" +``` diff --git a/bionemo-recipes/recipes/vllm_inference/esm2/convert.py b/bionemo-recipes/recipes/vllm_inference/esm2/convert.py new file mode 100644 index 0000000000..da0f650883 --- /dev/null +++ b/bionemo-recipes/recipes/vllm_inference/esm2/convert.py @@ -0,0 +1,292 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Conversion utilities between HuggingFace ESM2 and TransformerEngine formats.""" + +import inspect + +import torch +from accelerate import init_empty_weights +from torch import nn +from transformers import EsmConfig, EsmForMaskedLM + +import state +from modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM + + +mapping = { + "esm.encoder.layer.*.attention.output.dense.weight": "model.encoder.layers.*.self_attention.proj.weight", + "esm.encoder.layer.*.attention.output.dense.bias": "model.encoder.layers.*.self_attention.proj.bias", + "esm.encoder.layer.*.attention.LayerNorm.weight": "model.encoder.layers.*.self_attention.layernorm_qkv.layer_norm_weight", + "esm.encoder.layer.*.attention.LayerNorm.bias": "model.encoder.layers.*.self_attention.layernorm_qkv.layer_norm_bias", + "esm.encoder.layer.*.intermediate.dense.weight": "model.encoder.layers.*.layernorm_mlp.fc1_weight", + "esm.encoder.layer.*.intermediate.dense.bias": "model.encoder.layers.*.layernorm_mlp.fc1_bias", + "esm.encoder.layer.*.output.dense.weight": "model.encoder.layers.*.layernorm_mlp.fc2_weight", + "esm.encoder.layer.*.output.dense.bias": "model.encoder.layers.*.layernorm_mlp.fc2_bias", + "esm.encoder.layer.*.LayerNorm.weight": "model.encoder.layers.*.layernorm_mlp.layer_norm_weight", + "esm.encoder.layer.*.LayerNorm.bias": "model.encoder.layers.*.layernorm_mlp.layer_norm_bias", + "esm.encoder.emb_layer_norm_after.weight": "model.encoder.emb_layer_norm_after.weight", + "esm.encoder.emb_layer_norm_after.bias": "model.encoder.emb_layer_norm_after.bias", + "lm_head.dense.weight": "lm_head.dense.weight", + "lm_head.dense.bias": "lm_head.dense.bias", + "lm_head.layer_norm.weight": "lm_head.decoder.layer_norm_weight", + "lm_head.layer_norm.bias": "lm_head.decoder.layer_norm_bias", +} + +# Reverse mapping from TE to HF format by reversing the original mapping +reverse_mapping = {v: k for k, v in mapping.items()} + + +def convert_esm_hf_to_te(model_hf: nn.Module, **config_kwargs) -> nn.Module: + 
"""Convert a Hugging Face model to a Transformer Engine model. + + Args: + model_hf (nn.Module): The Hugging Face model. + **config_kwargs: Additional configuration kwargs to be passed to NVEsmConfig. + + Returns: + nn.Module: The Transformer Engine model. + """ + # TODO (peter): this is super similar method to the AMPLIFY one, maybe we can abstract or keep simlar naming? models/amplify/src/amplify/state_dict_convert.py:convert_amplify_hf_to_te + te_config = NVEsmConfig(**model_hf.config.to_dict(), **config_kwargs) + with init_empty_weights(): + model_te = NVEsmForMaskedLM(te_config) + + output_model = state.apply_transforms( + model_hf, + model_te, + mapping, + [ + _pack_qkv_weight, + _pack_qkv_bias, + _pad_embeddings, + _pad_decoder_weights, + _pad_bias, + ], + ) + + return output_model + + +def convert_esm_te_to_hf(model_te: nn.Module, **config_kwargs) -> nn.Module: + """Convert a Transformer Engine model back to the original HuggingFace Facebook ESM-2 format. + + This function converts from the NVIDIA Transformer Engine (TE) format back to the + weight format compatible with the original facebook/esm2_* series of checkpoints. + The TE model is also a HuggingFace model, but this conversion ensures compatibility + with the original Facebook ESM-2 model architecture and weight format hosted on Hugging Face. + + Args: + model_te (nn.Module): The Transformer Engine model. + **config_kwargs: Additional configuration kwargs to be passed to EsmConfig. + + Returns: + nn.Module: The Hugging Face model in original Facebook ESM-2 format hosted on Hugging Face. 
+ """ + # Convert TE config to HF config, filtering out TE-specific keys + te_config_dict = model_te.config.to_dict() + valid_keys = set(inspect.signature(EsmConfig.__init__).parameters) + filtered_config = {k: v for k, v in te_config_dict.items() if k in valid_keys} + hf_config = EsmConfig(**filtered_config, **config_kwargs) + + with init_empty_weights(): + model_hf = EsmForMaskedLM(hf_config) + + # Remove contact_head since it's not present in TE models + if hasattr(model_hf.esm, "contact_head"): + delattr(model_hf.esm, "contact_head") + + output_model = state.apply_transforms( + model_te, + model_hf, + reverse_mapping, + [_unpack_qkv_weight, _unpack_qkv_bias, _unpad_embeddings, _unpad_decoder_weights, _unpad_bias], + state_dict_ignored_entries=[ + "lm_head.decoder.weight", + "esm.contact_head.regression.weight", + "esm.contact_head.regression.bias", + ], + ) + + output_model.post_init() + + # Note: contact_head parameters are not preserved in TE models + # They are lost during HF -> TE conversion and cannot be recovered + # The converted model will not have the original contact_head weights + + return output_model + + +@state.state_transform( + source_key=( + "esm.encoder.layer.*.attention.self.query.weight", + "esm.encoder.layer.*.attention.self.key.weight", + "esm.encoder.layer.*.attention.self.value.weight", + ), + target_key="model.encoder.layers.*.self_attention.layernorm_qkv.weight", +) +def _pack_qkv_weight(ctx: state.TransformCTX, query, key, value): + """Pack separate Q, K, V weight tensors into a single interleaved QKV weight tensor.""" + concat_weights = torch.cat((query, key, value), dim=0) + input_shape = concat_weights.size() + num_heads = ctx.target.config.num_attention_heads + # transpose weights + # [sequence length, batch size, num_splits_model_parallel * attention head size * #attention heads] + # --> [sequence length, batch size, attention head size * num_splits_model_parallel * #attention heads] + concat_weights = concat_weights.view(3, 
num_heads, -1, query.size()[-1]) + concat_weights = concat_weights.transpose(0, 1).contiguous() + concat_weights = concat_weights.view(*input_shape) + return concat_weights + + +@state.state_transform( + source_key=( + "esm.encoder.layer.*.attention.self.query.bias", + "esm.encoder.layer.*.attention.self.key.bias", + "esm.encoder.layer.*.attention.self.value.bias", + ), + target_key="model.encoder.layers.*.self_attention.layernorm_qkv.bias", +) +def _pack_qkv_bias(ctx: state.TransformCTX, query, key, value): + """Pack separate Q, K, V bias tensors into a single interleaved QKV bias tensor.""" + concat_biases = torch.cat((query, key, value), dim=0) + input_shape = concat_biases.size() + num_heads = ctx.target.config.num_attention_heads + # transpose biases + # [num_splits_model_parallel * attention head size * #attention heads] + # --> [attention head size * num_splits_model_parallel * #attention heads] + concat_biases = concat_biases.view(3, num_heads, -1) + concat_biases = concat_biases.transpose(0, 1).contiguous() + concat_biases = concat_biases.view(*input_shape) + return concat_biases + + +@state.state_transform( + source_key="model.encoder.layers.*.self_attention.layernorm_qkv.weight", + target_key=( + "esm.encoder.layer.*.attention.self.query.weight", + "esm.encoder.layer.*.attention.self.key.weight", + "esm.encoder.layer.*.attention.self.value.weight", + ), +) +def _unpack_qkv_weight(ctx: state.TransformCTX, qkv_weight): + """Unpack fused QKV weights into separate [hidden_size, input_dim] tensors for query/key/value.""" + num_heads = ctx.source.config.num_attention_heads + total_rows, input_dim = qkv_weight.size() # size: [num_heads * 3 *head_dim, input_dim] + assert total_rows % (3 * num_heads) == 0, ( + f"QKV weight rows {total_rows} not divisible by 3*num_heads {3 * num_heads}" + ) + head_dim = total_rows // (3 * num_heads) + + qkv_weight = ( + qkv_weight.view(num_heads, 3, head_dim, input_dim).transpose(0, 1).contiguous() + ) # size: [3, num_heads, 
head_dim, input_dim] + query, key, value = qkv_weight[0], qkv_weight[1], qkv_weight[2] # size: [num_heads, head_dim, input_dim] + + query = query.reshape(-1, input_dim) # size: [num_heads * head_dim, input_dim] + key = key.reshape(-1, input_dim) # size: [num_heads * head_dim, input_dim] + value = value.reshape(-1, input_dim) # size: [num_heads * head_dim, input_dim] + + return query, key, value + + +@state.state_transform( + source_key="model.encoder.layers.*.self_attention.layernorm_qkv.bias", + target_key=( + "esm.encoder.layer.*.attention.self.query.bias", + "esm.encoder.layer.*.attention.self.key.bias", + "esm.encoder.layer.*.attention.self.value.bias", + ), +) +def _unpack_qkv_bias(ctx: state.TransformCTX, qkv_bias): + """Unpack fused QKV biases into separate [hidden_size] tensors for query/key/value.""" + num_heads = ctx.source.config.num_attention_heads + total_size = qkv_bias.size(0) # size: [num_heads * 3 * head_dim] + assert total_size % (3 * num_heads) == 0, ( + f"QKV bias size {total_size} not divisible by 3*num_heads {3 * num_heads}" + ) + head_dim = total_size // (3 * num_heads) + + qkv_bias = qkv_bias.view(num_heads, 3, head_dim).transpose(0, 1).contiguous() # size: [3, num_heads, head_dim] + query, key, value = qkv_bias[0], qkv_bias[1], qkv_bias[2] # size: [num_heads, head_dim] + + query = query.reshape(-1) # size: [num_heads * head_dim] + key = key.reshape(-1) # size: [num_heads * head_dim] + value = value.reshape(-1) # size: [num_heads * head_dim] + + return query, key, value + + +def _unpad_weights(ctx: state.TransformCTX, padded_embed): + """Remove padding from the embedding layer to get back to the original dimension.""" + target_embedding_dimension = ctx.target.config.vocab_size + return padded_embed[:target_embedding_dimension] + + +def _pad_weights(ctx: state.TransformCTX, source_embed): + """Pad the embedding layer to the new input dimension.""" + target_embedding_dimension = ctx.target.config.padded_vocab_size + hf_embedding_dimension = 
source_embed.size(0) + num_padding_rows = target_embedding_dimension - hf_embedding_dimension + padding_rows = torch.zeros( + num_padding_rows, source_embed.size(1), dtype=source_embed.dtype, device=source_embed.device + ) + return torch.cat((source_embed, padding_rows), dim=0) + + +_pad_embeddings = state.state_transform( + source_key="esm.embeddings.word_embeddings.weight", + target_key="model.embeddings.word_embeddings.weight", +)(_pad_weights) + +_pad_decoder_weights = state.state_transform( + source_key="lm_head.decoder.weight", + target_key="lm_head.decoder.weight", +)(_pad_weights) + +_unpad_embeddings = state.state_transform( + source_key="model.embeddings.word_embeddings.weight", + target_key="esm.embeddings.word_embeddings.weight", +)(_unpad_weights) + +_unpad_decoder_weights = state.state_transform( + source_key="lm_head.decoder.weight", + target_key="lm_head.decoder.weight", +)(_unpad_weights) + + +@state.state_transform( + source_key="lm_head.bias", + target_key="lm_head.decoder.bias", +) +def _pad_bias(ctx: state.TransformCTX, source_bias): + """Pad the embedding layer to the new input dimension.""" + target_embedding_dimension = ctx.target.config.padded_vocab_size + hf_embedding_dimension = source_bias.size(0) + output_bias = torch.finfo(source_bias.dtype).min * torch.ones( + target_embedding_dimension, dtype=source_bias.dtype, device=source_bias.device + ) + output_bias[:hf_embedding_dimension] = source_bias + return output_bias + + +@state.state_transform( + source_key="lm_head.decoder.bias", + target_key="lm_head.bias", +) +def _unpad_bias(ctx: state.TransformCTX, padded_bias): + """Remove padding from the bias to get back to the original dimension.""" + target_embedding_dimension = ctx.target.config.vocab_size + return padded_bias[:target_embedding_dimension] diff --git a/bionemo-recipes/recipes/vllm_inference/esm2/esm_fast_tokenizer/special_tokens_map.json b/bionemo-recipes/recipes/vllm_inference/esm2/esm_fast_tokenizer/special_tokens_map.json 
new file mode 100644 index 0000000000..9a725bd8b1 --- /dev/null +++ b/bionemo-recipes/recipes/vllm_inference/esm2/esm_fast_tokenizer/special_tokens_map.json @@ -0,0 +1,7 @@ +{ + "cls_token": "", + "eos_token": "", + "mask_token": "", + "pad_token": "", + "unk_token": "" +} diff --git a/bionemo-recipes/recipes/vllm_inference/esm2/esm_fast_tokenizer/tokenizer.json b/bionemo-recipes/recipes/vllm_inference/esm2/esm_fast_tokenizer/tokenizer.json new file mode 100644 index 0000000000..2dc5f3e72b --- /dev/null +++ b/bionemo-recipes/recipes/vllm_inference/esm2/esm_fast_tokenizer/tokenizer.json @@ -0,0 +1,168 @@ +{ + "version": "1.0", + "truncation": null, + "padding": null, + "added_tokens": [ + { + "id": 0, + "content": "", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 1, + "content": "", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 2, + "content": "", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 3, + "content": "", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 32, + "content": "", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + } + ], + "normalizer": null, + "pre_tokenizer": { + "type": "Split", + "pattern": { + "String": "" + }, + "behavior": "Isolated", + "invert": false + }, + "post_processor": { + "type": "TemplateProcessing", + "single": [ + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "A", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + } + ], + "pair": [ + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "A", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "", + "type_id": 0 + } + }, + { + 
"Sequence": { + "id": "B", + "type_id": 1 + } + }, + { + "SpecialToken": { + "id": "", + "type_id": 1 + } + } + ], + "special_tokens": { + "": { + "id": "", + "ids": [0], + "tokens": [""] + }, + "": { + "id": "", + "ids": [2], + "tokens": [""] + } + } + }, + "decoder": null, + "model": { + "type": "WordLevel", + "vocab": { + "": 0, + "": 1, + "": 2, + "": 3, + "L": 4, + "A": 5, + "G": 6, + "V": 7, + "S": 8, + "E": 9, + "R": 10, + "T": 11, + "I": 12, + "D": 13, + "P": 14, + "K": 15, + "Q": 16, + "N": 17, + "F": 18, + "Y": 19, + "M": 20, + "H": 21, + "W": 22, + "C": 23, + "X": 24, + "B": 25, + "U": 26, + "Z": 27, + "O": 28, + ".": 29, + "-": 30, + "": 31, + "": 32 + }, + "unk_token": "" + } +} diff --git a/bionemo-recipes/recipes/vllm_inference/esm2/esm_fast_tokenizer/tokenizer_config.json b/bionemo-recipes/recipes/vllm_inference/esm2/esm_fast_tokenizer/tokenizer_config.json new file mode 100644 index 0000000000..d6b7bc43be --- /dev/null +++ b/bionemo-recipes/recipes/vllm_inference/esm2/esm_fast_tokenizer/tokenizer_config.json @@ -0,0 +1,60 @@ +{ + "added_tokens_decoder": { + "0": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "1": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "2": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "3": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + } + }, + "clean_up_tokenization_spaces": false, + "bos_token": "", + "cls_token": "", + "eos_token": "", + "mask_token": "", + "pad_token": "", + "unk_token": "", + "extra_special_tokens": {}, + "model_max_length": 
1000000000000000019884624838656, + "tokenizer_class": "PreTrainedTokenizerFast", + "add_bos_token": true, + "add_eos_token": true, + "model_input_names": [ + "input_ids", + "attention_mask" + ] +} diff --git a/bionemo-recipes/recipes/vllm_inference/esm2/export.py b/bionemo-recipes/recipes/vllm_inference/esm2/export.py new file mode 100644 index 0000000000..748e467843 --- /dev/null +++ b/bionemo-recipes/recipes/vllm_inference/esm2/export.py @@ -0,0 +1,141 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import gc +import json +import shutil +from pathlib import Path + +import torch +from jinja2 import Template +from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer + +from convert import convert_esm_hf_to_te +from modeling_esm_te import AUTO_MAP + + +BENCHMARK_RESULTS = { + "esm2_t6_8M_UR50D": {"CAMEO": 0.48, "CASP14": 0.37}, + "esm2_t12_35M_UR50D": {"CAMEO": 0.56, "CASP14": 0.41}, + "esm2_t30_150M_UR50D": {"CAMEO": 0.65, "CASP14": 0.49}, + "esm2_t33_650M_UR50D": {"CAMEO": 0.70, "CASP14": 0.51}, + "esm2_t36_3B_UR50D": {"CAMEO": 0.72, "CASP14": 0.52}, + "esm2_t48_15B_UR50D": {"CAMEO": 0.72, "CASP14": 0.55}, +} + + +ESM_TAGS = [ + "esm2_t6_8M_UR50D", + "esm2_t12_35M_UR50D", + "esm2_t30_150M_UR50D", + "esm2_t33_650M_UR50D", + "esm2_t36_3B_UR50D", + "esm2_t48_15B_UR50D", +] + + +def format_parameter_count(num_params: int, sig: int = 1) -> str: + """Format parameter count in scientific notation (e.g., 6.5 x 10^8). + + Args: + num_params: Total number of parameters + sig: Number of digits to include after the decimal point + + Returns: + Formatted string in scientific notation + """ + s = f"{num_params:.{sig}e}" + base, exp = s.split("e") + return f"{base} x 10^{int(exp)}" + + +def export_hf_checkpoint(tag: str, export_path: Path): + """Export a Hugging Face checkpoint to a Transformer Engine checkpoint. + + Args: + tag: The tag of the checkpoint to export. + export_path: The parent path to export the checkpoint to. + """ + model_hf_masked_lm = AutoModelForMaskedLM.from_pretrained(f"facebook/{tag}") + model_hf = AutoModel.from_pretrained(f"facebook/{tag}") + model_hf_masked_lm.esm.pooler = model_hf.pooler + + # Export without vocab padding so the checkpoint stores embeddings at the real + # vocab_size. This avoids shape-mismatch errors in vLLM's VocabParallelEmbedding, + # which expects vocab_size-shaped weights. 
+ model_te = convert_esm_hf_to_te(model_hf_masked_lm, padded_vocab_size=None) + model_te.save_pretrained(export_path / tag) + + tokenizer = AutoTokenizer.from_pretrained("esm_fast_tokenizer") # Use our PreTrainedTokenizerFast implementation. + tokenizer.save_pretrained(export_path / tag) + + # Patch the config + with open(export_path / tag / "config.json", "r") as f: + config = json.load(f) + + config["auto_map"] = AUTO_MAP + + with open(export_path / tag / "config.json", "w") as f: + json.dump(config, f, indent=2, sort_keys=True) + + shutil.copy("modeling_esm_te.py", export_path / tag / "esm_nv.py") + + # Calculate model parameters and render README template + num_params = sum(p.numel() for p in model_te.parameters()) + formatted_params = format_parameter_count(num_params) + + # Read and render the template + with open("model_readme.template", "r", encoding="utf-8") as f: + template_content = f.read() + + template = Template(template_content) + rendered_readme = template.render( + num_params=formatted_params, + model_tag=tag, + cameo_score=BENCHMARK_RESULTS[tag]["CAMEO"], + casp14_score=BENCHMARK_RESULTS[tag]["CASP14"], + ) + + # Write the rendered README + with open(export_path / tag / "README.md", "w") as f: + f.write(rendered_readme) + + shutil.copy("LICENSE", export_path / tag / "LICENSE") + + del model_hf, model_te, model_hf_masked_lm + gc.collect() + torch.cuda.empty_cache() + + # Smoke test that the model can be loaded. + model_te = AutoModelForMaskedLM.from_pretrained( + export_path / tag, + dtype=torch.bfloat16, + trust_remote_code=True, + ) + del model_te + gc.collect() + torch.cuda.empty_cache() + + +def main(): + """Export the ESM2 models from Hugging Face to the Transformer Engine format.""" + # TODO (peter): maybe add a way to specify the model to export or option to export all models? 
+ for tag in ESM_TAGS: + print(f"Converting {tag}...") + export_hf_checkpoint(tag, Path("./checkpoint_export")) + + +if __name__ == "__main__": + main() diff --git a/bionemo-recipes/recipes/vllm_inference/esm2/install_vllm.sh b/bionemo-recipes/recipes/vllm_inference/esm2/install_vllm.sh new file mode 100755 index 0000000000..a761046837 --- /dev/null +++ b/bionemo-recipes/recipes/vllm_inference/esm2/install_vllm.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +set -euo pipefail + +ARCH="${1:-$(python3 -c "import torch; cc = torch.cuda.get_device_capability(); print(f'{cc[0]}.{cc[1]}')")}" +MAX_JOBS="${MAX_JOBS:-8}" +export UV_BREAK_SYSTEM_PACKAGES=1 + +echo "Building vLLM for CUDA arch: $ARCH (MAX_JOBS=$MAX_JOBS)" + +cd /workspace +if [ ! -d vllm ]; then + git clone --branch v0.15.1 --depth 1 https://github.com/vllm-project/vllm.git +fi +cd vllm +python use_existing_torch.py +TORCH_CUDA_ARCH_LIST="$ARCH" MAX_JOBS="$MAX_JOBS" \ + uv pip install -r requirements/build.txt --system +TORCH_CUDA_ARCH_LIST="$ARCH" MAX_JOBS="$MAX_JOBS" \ + uv pip install --no-build-isolation -e . --system +pip install --upgrade "transformers[torch]" + +echo "vLLM installed for arch $ARCH" diff --git a/bionemo-recipes/recipes/vllm_inference/esm2/modeling_esm_te.py b/bionemo-recipes/recipes/vllm_inference/esm2/modeling_esm_te.py new file mode 100644 index 0000000000..298fc70f6b --- /dev/null +++ b/bionemo-recipes/recipes/vllm_inference/esm2/modeling_esm_te.py @@ -0,0 +1,786 @@ +# noqa: license-check +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved. +# Copyright 2025 NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""TransformerEngine-optimized ESM model. + +Adapted from `modeling_esm.py` in huggingface/transformers. +""" + +import warnings +from contextlib import nullcontext +from typing import ClassVar, Literal, Optional, Unpack + +# TODO: put import guard around transformer_engine here, with an informative error message around +# installation and the nvidia docker container. +import torch +import transformer_engine.common.recipe +import transformer_engine.pytorch +from torch import nn +from torch.nn import CrossEntropyLoss +from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPooling, + MaskedLMOutput, + TokenClassifierOutput, +) +from transformers.models.esm.configuration_esm import EsmConfig +from transformers.models.esm.modeling_esm import EsmPooler, EsmPreTrainedModel +from transformers.utils import logging +from transformers.utils.generic import TransformersKwargs + + +logger = logging.get_logger(__name__) + +# Dictionary that gets inserted into config.json to map Auto** classes to our TE-optimized model classes defined below. +# These should be prefixed with esm_nv., since we name the file esm_nv.py in our exported checkpoints. 
+AUTO_MAP = { + "AutoConfig": "esm_nv.NVEsmConfig", + "AutoModel": "esm_nv.NVEsmModel", + "AutoModelForMaskedLM": "esm_nv.NVEsmForMaskedLM", + "AutoModelForTokenClassification": "esm_nv.NVEsmForTokenClassification", +} + + +class NVEsmConfig(EsmConfig): + """NVEsmConfig is a configuration for the NVEsm model.""" + + model_type: str = "nv_esm" + + def __init__( + self, + qkv_weight_interleaved: bool = True, + encoder_activation: str = "gelu", + attn_input_format: Literal["bshd", "thd"] = "bshd", + fuse_qkv_params: bool = True, + micro_batch_size: Optional[int] = None, + max_seq_length: Optional[int] = None, + padded_vocab_size: Optional[int] = 64, + attn_mask_type: str = "padding", + add_pooling_layer: bool = False, + layer_precision: list[str | None] | None = None, + **kwargs, + ): + """Initialize the NVEsmConfig with additional TE-related config options. + + Args: + qkv_weight_interleaved: Whether to interleave the qkv weights. If set to `False`, the + QKV weight is interpreted as a concatenation of query, key, and value weights along + the `0th` dimension. The default interpretation is that the individual `q`, `k`, and + `v` weights for each attention head are interleaved. This parameter is set to `False` + when using :attr:`fuse_qkv_params=False`. + encoder_activation: The activation function to use in the encoder. + attn_input_format: The input format to use for the attention: + "bshd" = Batch, Sequence, Head, Dimension (standard padded format) + "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format) + Note that these formats are very closely related to the `qkv_format` in the + `MultiHeadAttention` and `DotProductAttention` modules. + fuse_qkv_params: Whether to fuse the qkv parameters. If set to `True`, + `TransformerLayer` module exposes a single fused parameter for query-key-value. + This enables optimizations such as QKV fusion without concatentations/splits and + also enables the argument `fuse_wgrad_accumulation`. 
+ micro_batch_size: The micro batch size to use for the attention. This is needed for + JIT Warmup, a technique where jit fused functions are warmed up before training to + ensure same kernels are used for forward propogation and activation recompute phase. + max_seq_length: The maximum sequence length to use for the attention. This is needed for + JIT Warmup, a technique where jit fused functions are warmed up before training to + ensure same kernels are used for forward propogation and activation recompute phase. + padded_vocab_size: The padded vocabulary size to support FP8. If not provided, defaults + to vocab_size. Must be greater than or equal to vocab_size. + attn_mask_type: The type of attention mask to use. + add_pooling_layer: Whether the base model should include a pooling layer. + Defaults to ``False`` because exported checkpoints do not contain pooler + weights. Set to ``True`` only if you have a checkpoint with pooler weights. + layer_precision: Per-layer quantization precision, a list of length ``num_hidden_layers`` + where each element is ``"fp8"``, ``"fp4"``, or ``None`` (BF16 fallback). ``None`` + (the default) means no quantization is configured. + **kwargs: Additional config options to pass to EsmConfig. + """ + super().__init__(**kwargs) + # Additional TE-related config options. 
+ self.qkv_weight_interleaved = qkv_weight_interleaved + self.encoder_activation = encoder_activation + self.attn_input_format = attn_input_format + self.fuse_qkv_params = fuse_qkv_params + self.micro_batch_size = micro_batch_size + self.max_seq_length = max_seq_length + self.attn_mask_type = attn_mask_type + self.add_pooling_layer = add_pooling_layer + self.layer_precision = layer_precision + + # Set padded_vocab_size with default fallback to vocab_size + self.padded_vocab_size = padded_vocab_size if padded_vocab_size is not None else self.vocab_size + + # Ensure padded_vocab_size is at least as large as vocab_size + if self.padded_vocab_size is not None and self.vocab_size is not None: + assert self.padded_vocab_size >= self.vocab_size, ( + f"padded_vocab_size ({self.padded_vocab_size}) must be greater than or equal to vocab_size ({self.vocab_size})" + ) + + +class NVEsmEncoder(nn.Module): + """NVEsmEncoder is a TransformerEngine-optimized ESM encoder.""" + + def __init__(self, config: NVEsmConfig): + """Initialize a NVEsmEncoder. + + Args: + config (NVEsmConfig): The configuration of the model. 
+ """ + super().__init__() + self.config = config + + def _init_method(x): + torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range) + + self.layers = nn.ModuleList( + [ + transformer_engine.pytorch.TransformerLayer( + hidden_size=config.hidden_size, + ffn_hidden_size=config.intermediate_size, + num_attention_heads=config.num_attention_heads, + layernorm_epsilon=config.layer_norm_eps, + hidden_dropout=config.hidden_dropout_prob, + attention_dropout=config.attention_probs_dropout_prob, + qkv_weight_interleaved=config.qkv_weight_interleaved, + layer_number=i + 1, + layer_type="encoder", + self_attn_mask_type=config.attn_mask_type, + activation=config.encoder_activation, + attn_input_format=config.attn_input_format, + seq_length=config.max_seq_length, + micro_batch_size=config.micro_batch_size, + num_gqa_groups=config.num_attention_heads, + fuse_qkv_params=config.fuse_qkv_params, + params_dtype=config.dtype, + window_size=(-1, -1), + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=_init_method, + output_layer_init_method=_init_method, + ) + for i in range(config.num_hidden_layers) + ] + ) + self._fp8_recipe: transformer_engine.common.recipe.Recipe | None = None + self._fp4_recipe: transformer_engine.common.recipe.Recipe | None = None + self.emb_layer_norm_after = transformer_engine.pytorch.LayerNorm( + config.hidden_size, + eps=config.layer_norm_eps, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + ) + if config.position_embedding_type == "rotary": + self.rotary_embeddings = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads) + + def set_recipes( + self, + fp8_recipe: transformer_engine.common.recipe.Recipe | None = None, + fp4_recipe: transformer_engine.common.recipe.Recipe | None = None, + ) -> None: + """Attach quantization recipe objects for per-layer autocast. 
+ + Recipes are not serializable and must be set at runtime after model creation + and sharding (FSDP/DDP/mFSDP) but before training. The per-layer precision + assignments are read from ``self.config.layer_precision``. + + These recipes are also hardware specific, so we should not store them as + attributes of the model and attach them at runtime. + + Args: + fp8_recipe: The FP8 recipe instance (e.g., MXFP8BlockScaling), or None. + fp4_recipe: The FP4 recipe instance (e.g., NVFP4BlockScaling), or None. + """ + self._fp8_recipe = fp8_recipe + self._fp4_recipe = fp4_recipe + + def get_layer_autocast(self, layer_number: int): + """Return the appropriate TE autocast context manager for a given layer. + + The context interacts with the outer FP8 autocast in the training script: + - FP8 layer: nullcontext() -- lets the outer FP8 autocast take effect. + - FP4 layer: te.pytorch.autocast(enabled=True, recipe=fp4_recipe) -- overrides to FP4. + - BF16 layer: te.pytorch.autocast(enabled=False) -- disables quantized compute. + + Args: + layer_number: The 0-indexed layer number. + + Returns: + A context manager for the layer's quantization mode. + """ + precision = self.config.layer_precision[layer_number] if self.config.layer_precision is not None else None + if precision == "fp8": + return nullcontext() + elif precision == "fp4": + return transformer_engine.pytorch.autocast(enabled=True, recipe=self._fp4_recipe) + else: + return transformer_engine.pytorch.autocast(enabled=False) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ): + """Forward pass of the NVEsmEncoder. + + Args: + hidden_states (torch.Tensor): The hidden states. + attention_mask (torch.Tensor): The attention mask. + **kwargs: Additional arguments, see TransformersKwargs for more details. + """ + all_hidden_states: tuple[torch.Tensor, ...] 
= () + + if self.config.attn_input_format == "thd" and hidden_states.dim() == 3 and hidden_states.size(0) == 1: + # For THD, the embedding output is a 3-dimensional tensor with shape [1, total_tokens, hidden_size], but TE + # expects a 2-dimensional tensor with shape [total_tokens, hidden_size]. + hidden_states = hidden_states.squeeze(0) + + # Ensure that rotary embeddings are computed with at a higher precision outside the torch autocast context. + with torch.autocast(device_type="cuda", enabled=False): + te_rope_emb = self.rotary_embeddings(max_seq_len=self.config.max_position_embeddings) + te_rope_emb = te_rope_emb.to(hidden_states.device, non_blocking=True) + if te_rope_emb.dtype == torch.float32: + warnings.warn("Rotary embeddings should be in float32 for optimal performance.", UserWarning) + + # Outer FP8 autocast enables FP8 compute for the encoder stack. Per-layer overrides (FP4, BF16) are handled + # by get_layer_autocast(), which nests inside this context. + with transformer_engine.pytorch.autocast(enabled=self._fp8_recipe is not None, recipe=self._fp8_recipe): + for layer_number, layer_module in enumerate(self.layers): + if kwargs.get("output_hidden_states", False): + all_hidden_states = (*all_hidden_states, hidden_states) + + with self.get_layer_autocast(layer_number): + hidden_states = layer_module( + hidden_states, + attention_mask, + rotary_pos_emb=te_rope_emb, + cu_seqlens_q=kwargs.get("cu_seq_lens_q", None), + cu_seqlens_kv=kwargs.get("cu_seq_lens_k", None), + cu_seqlens_q_padded=kwargs.get("cu_seq_lens_q_padded", None), + cu_seqlens_kv_padded=kwargs.get("cu_seq_lens_k_padded", None), + max_seqlen_q=kwargs.get("max_length_q", None), + max_seqlen_kv=kwargs.get("max_length_k", None), + pad_between_seqs=kwargs.get("pad_between_seqs", None), + ) + + hidden_states = self.emb_layer_norm_after(hidden_states) + + if kwargs.get("output_hidden_states", False): + all_hidden_states = (*all_hidden_states, hidden_states) + + return BaseModelOutput( + 
last_hidden_state=hidden_states, + hidden_states=all_hidden_states if all_hidden_states else None, + ) + + +class NVEsmPreTrainedModel(EsmPreTrainedModel): + """An abstract class to handle weights initialization and pretrained model loading.""" + + config_class = NVEsmConfig + base_model_prefix = "model" + supports_gradient_checkpointing = False + accepts_loss_kwargs = False + _no_split_modules = ( + "TransformerLayer", + "EsmEmbeddings", + ) + + def init_empty_weights(self): + """Handles moving the model from the meta device to the cuda device and initializing the weights.""" + # For TE layers, calling `reset_parameters` is sufficient to move them to the cuda device and apply the weight + # initialization we passed them during module creation. + for module in self.modules(): + if hasattr(module, "reset_parameters"): + module.reset_parameters() + + # The embeddings layer is the only non-TE layer in this model we need to deal with. We use + # `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard + # deviation. self.base_model resolves to self.model for wrapper classes or self for NVEsmModel. + self.base_model.embeddings.word_embeddings.to_empty(device="cuda") + self.base_model.embeddings.apply(self._init_weights) + + # Meta-device init seems to break weight tying, so we re-tie the weights here. + self.tie_weights() + + def _init_weights(self, module): + """Initialize module weights. + + We only use this method for standard pytorch modules, TE modules handle their own weight initialization through + `init_method` parameters and the `reset_parameters` method. 
+ """ + if module.__module__.startswith("transformer_engine.pytorch"): + # Notably, we need to avoid calling the parent method for TE modules, since the default _init_weights will + # assume any class with `LayerNorm` in the name should have weights initialized to 1.0; breaking + # `LayerNormLinear` and `LayerNormMLP` modules that use `weight` for the linear layer and + # `layer_norm_weight` for the layer norm. Instead, we call `reset_parameters` if the module has it and the + # weights are not in fp8. We still need to figure out why this raises an error if we're using + # `quantized_model_init`. + if hasattr(module, "reset_parameters") and not getattr(module, "primary_weights_in_fp8", False): + module.reset_parameters() + return + + super()._init_weights(module) + + def state_dict(self, *args, **kwargs): + """Override state_dict to filter out non-loadable keys. + + Filters out: + - ``_extra_state`` keys: TransformerEngine-specific, not loadable by HuggingFace v5. + - ``.inv_freq`` buffers: Computed at init time by RotaryPositionEmbedding, not needed + in the checkpoint and not loadable by vLLM's AutoWeightsLoader (which only iterates + over ``named_parameters``, not ``named_buffers``). + """ + state_dict = super().state_dict(*args, **kwargs) + return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state") and not k.endswith(".inv_freq")} + + +class NVEsmModel(NVEsmPreTrainedModel): + """The ESM Encoder-only protein language model. + + This model uses NVDIA's TransformerEngine to optimize attention layer training and inference. + """ + + def __init__(self, config: NVEsmConfig, add_pooling_layer: Optional[bool] = None): + """Initialize a NVEsmModel. + + Args: + config (NVEsmConfig): The configuration of the model. + add_pooling_layer (bool): Whether to add a pooling layer. If ``None``, + reads ``config.add_pooling_layer`` (defaults to ``True``). 
+ """ + super().__init__(config) + self.config = config + + if add_pooling_layer is None: + add_pooling_layer = getattr(config, "add_pooling_layer", True) + + # Ensure pad_token_id is set properly, defaulting to 0 if not specified + if not hasattr(config, "pad_token_id") or config.pad_token_id is None: + config.pad_token_id = 0 + self.embeddings = NVEsmEmbeddings(config) + self.encoder = NVEsmEncoder(config) + self.pooler = EsmPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + """Get the input embeddings of the model.""" + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value: torch.Tensor): + """Set the input embeddings of the model. + + Args: + value (torch.Tensor): The input embeddings. + """ + self.embeddings.word_embeddings = value + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPooling: + """Forward pass of the NVEsmModel. + + Args: + input_ids (torch.Tensor): The input ids. + attention_mask (torch.Tensor): The attention mask. + position_ids (torch.Tensor): The position ids. + inputs_embeds (torch.Tensor): The input embeddings. + **kwargs: Additional arguments, see TransformersKwargs for more details. + + Returns: + BaseModelOutputWithPooling: The output of the model. 
+ """ + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length)), device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # TE expects a boolean attention mask, where 1s are masked and 0s are not masked + extended_attention_mask = extended_attention_mask < -1 + + embedding_output = self.embeddings( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + **kwargs, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=None if self.config.attn_input_format == "thd" else extended_attention_mask, + **kwargs, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + return BaseModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + ) + + +class NVEsmForMaskedLM(NVEsmPreTrainedModel): + """NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling.""" + + _tied_weights_keys: ClassVar[dict[str, str]] = { + "lm_head.decoder.weight": "model.embeddings.word_embeddings.weight" + } + _do_not_quantize = 
("lm_head.dense", "lm_head.decoder") # Flag for testing that these layers are not quantized. + + def __init__(self, config: NVEsmConfig): + """Initialize a NVEsmForMaskedLM. + + Args: + config (NVEsmConfig): The configuration of the model. + """ + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `EsmForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." + ) + + self.model = NVEsmModel(config, add_pooling_layer=False) + self.lm_head = NVEsmLMHead(config) + + self.post_init() + + def get_output_embeddings(self): + """Get the output embeddings of the model.""" + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + """Set the output embeddings of the model.""" + self.lm_head.decoder = new_embeddings + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MaskedLMOutput: + """Forward pass of the NVEsmForMaskedLM. + + Args: + input_ids (torch.LongTensor): The input ids. + attention_mask (torch.Tensor): The attention mask. + position_ids (torch.LongTensor): The position ids. + inputs_embeds (torch.FloatTensor): The input embeddings. + labels (torch.LongTensor): The labels. + **kwargs: Additional arguments, see TransformersKwargs for more details. + + Returns: + MaskedLMOutput: The output of the model. 
+ """ + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + **kwargs, + ) + sequence_output = outputs[0] + with transformer_engine.pytorch.autocast(enabled=False): + prediction_scores = self.lm_head(sequence_output) + + # Truncate logits back to original vocab_size if padding was used + if self.config.padded_vocab_size != self.config.vocab_size: + prediction_scores = prediction_scores[..., : self.config.vocab_size] + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), + labels.to(prediction_scores.device).view(-1), + ) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + ) + + +class NVEsmLMHead(nn.Module): + """ESM Head for masked language modeling using TransformerEngine.""" + + def __init__(self, config: NVEsmConfig): + """Initialize a NVEsmLMHead. + + Args: + config (NVEsmConfig): The configuration of the model. + """ + super().__init__() + with transformer_engine.pytorch.quantized_model_init(enabled=False): + self.dense = transformer_engine.pytorch.Linear( + config.hidden_size, + config.hidden_size, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + ) + + self.decoder = transformer_engine.pytorch.LayerNormLinear( + config.hidden_size, + config.padded_vocab_size if config.padded_vocab_size is not None else config.vocab_size, + bias=True, + eps=config.layer_norm_eps, + params_dtype=config.dtype, + device="meta" if torch.get_default_device() == torch.device("meta") else "cuda", + init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range), + ) + + def forward(self, features, **kwargs): + """Forward pass of the NVEsmLMHead. 
class NVEsmEmbeddings(nn.Module):
    """Modified version of EsmEmbeddings to support THD inputs."""

    def __init__(self, config):
        """Initialize a NVEsmEmbeddings."""
        super().__init__()
        self.word_embeddings = nn.Embedding(
            config.padded_vocab_size,
            config.hidden_size,
            padding_idx=config.pad_token_id,
            dtype=config.dtype,
        )

        # Optional pre-encoder LayerNorm (used by some ESM variants).
        self.layer_norm = (
            transformer_engine.pytorch.LayerNorm(
                config.hidden_size,
                eps=config.layer_norm_eps,
                params_dtype=config.dtype,
                device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
            )
            if config.emb_layer_norm_before
            else None
        )

        if config.position_embedding_type != "rotary":
            raise ValueError(
                "The TE-accelerated ESM-2 model only supports rotary position embeddings, received "
                f"{config.position_embedding_type}"
            )

        self.padding_idx = config.pad_token_id
        self.token_dropout = config.token_dropout
        self.mask_token_id = config.mask_token_id

    def _apply_token_dropout_bshd(self, embeddings, input_ids, attention_mask):
        """Apply token dropout scaling for BSHD-format inputs.

        Compensates for masked tokens by scaling unmasked embeddings based on the
        observed mask ratio per sequence.

        Args:
            embeddings: Token embeddings with masked positions already zeroed out.
            input_ids: Original input token IDs.
            attention_mask: Attention mask indicating valid tokens.

        Returns:
            Scaled embeddings tensor.
        """
        mask_ratio_train = 0.15 * 0.8  # Hardcoded as the ratio used in all ESM model training runs
        # Fall back to the full sequence length when no attention mask is provided.
        src_lengths = attention_mask.sum(-1) if attention_mask is not None else input_ids.shape[1]
        n_masked_per_seq = (input_ids == self.mask_token_id).sum(-1).float()
        mask_ratio_observed = n_masked_per_seq / src_lengths
        scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed)
        return (embeddings * scale_factor[:, None, None]).to(embeddings.dtype)

    def _apply_token_dropout_thd(self, embeddings, input_ids, kwargs):
        """Apply token dropout scaling for THD-format (packed sequence) inputs.

        Uses cumulative sequence lengths to compute per-sequence mask ratios and
        scales embeddings accordingly using repeat_interleave.

        Args:
            embeddings: Token embeddings with masked positions already zeroed out.
            input_ids: Original input token IDs.
            kwargs: Additional keyword arguments containing cu_seq_lens_q and optionally cu_seq_lens_q_padded.

        Returns:
            Scaled embeddings tensor.
        """
        mask_ratio_train = 0.15 * 0.8  # Hardcoded as the ratio used in all ESM model training runs
        src_lengths = torch.diff(kwargs["cu_seq_lens_q"])
        if "cu_seq_lens_q_padded" in kwargs:
            src_lengths_padded = torch.diff(kwargs["cu_seq_lens_q_padded"])
        else:
            src_lengths_padded = src_lengths
        # We need to find the number of masked tokens in each sequence in the padded batch.
        is_masked = (input_ids == self.mask_token_id).squeeze(0)
        n_masked_per_seq = torch.nested.nested_tensor_from_jagged(is_masked, offsets=kwargs["cu_seq_lens_q"]).sum(1)
        mask_ratio_observed = n_masked_per_seq.float() / src_lengths
        scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed)
        # Broadcast each per-sequence scale across that sequence's (padded) tokens.
        reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths_padded, dim=0)
        return (embeddings * reshaped_scale_factor.unsqueeze(-1)).to(embeddings.dtype)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """Forward pass of the NVEsmEmbeddings."""
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        # Note that if we want to support ESM-1 (not 1b!) in future then we need to support an
        # embedding_scale factor here.
        embeddings = inputs_embeds

        # Presence of all four packed-sequence kwargs signals THD (packed) layout, in which
        # case the attention mask is not used.
        if (
            kwargs.get("cu_seq_lens_q") is not None
            and kwargs.get("cu_seq_lens_k") is not None
            and kwargs.get("max_length_q") is not None
            and kwargs.get("max_length_k") is not None
        ):
            using_thd = True
            attention_mask = None
        else:
            using_thd = False

        # Matt: ESM has the option to handle masking in MLM in a slightly unusual way. If the token_dropout
        # flag is False then it is handled in the same way as BERT/RoBERTa. If it is set to True, however,
        # masked tokens are treated as if they were selected for input dropout and zeroed out.
        # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by
        # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample).
        # This is analogous to the way that dropout layers scale down outputs during evaluation when not
        # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training).
        if self.token_dropout and input_ids is not None:
            embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0)
            if using_thd:
                embeddings = self._apply_token_dropout_thd(embeddings, input_ids, kwargs)
            else:
                embeddings = self._apply_token_dropout_bshd(embeddings, input_ids, attention_mask)

        if self.layer_norm is not None:
            embeddings = self.layer_norm(embeddings)

        # Zero out embeddings at padding positions.
        if attention_mask is not None:
            embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)

        return embeddings
class NVEsmForTokenClassification(NVEsmPreTrainedModel):
    """Adds a token classification head to the model.

    Adapted from EsmForTokenClassification in Hugging Face Transformers `modeling_esm.py`.
    """

    def __init__(self, config):
        """Initialize NVEsmForTokenClassification."""
        super().__init__(config)
        self.num_labels = config.num_labels

        # No pooling layer: token classification needs per-token hidden states.
        self.model = NVEsmModel(config, add_pooling_layer=False)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = transformer_engine.pytorch.Linear(
            config.hidden_size,
            config.num_labels,
            params_dtype=config.dtype,
            # Defer weight materialization when constructing under a meta-device context.
            device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
            init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range),
        )

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> TokenClassifierOutput:
        """Forward pass for the token classification head.

        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()

            labels = labels.to(logits.device)
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
+ +"""State dict conversion utilities adapted from nemo.lightning.io.state. + +This module provides the transform system used by convert.py to map state dicts between model formats: + +- ``mapping``: A dict of simple key renames (source_key -> target_key). Each source key is copied directly + to the corresponding target key with no modification to the tensor values. + +- ``transforms``: A list of ``StateDictTransform`` objects for multi-key merges and splits. These handle + cases where multiple source keys must be combined into one target key (e.g., merging Q/K/V into fused QKV), + or one source key must be split into multiple target keys. + + Important: When ``source_key`` is a tuple (many-to-one merge), the transform function's parameter names + are used to map each source key to a function argument. This means ``*args`` style parameters do not work; + each parameter must be explicitly named (e.g., ``def fn(q, k, v)`` not ``def fn(*args)``). +""" + +import inspect +import logging +import re +from dataclasses import dataclass +from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union, overload + +import numpy as np +import torch +from torch import nn + + +logger = logging.getLogger(__name__) + +SourceModuleT = TypeVar("SourceModuleT", bound=nn.Module) +TargetModuleT = TypeVar("TargetModuleT", bound=nn.Module) +F = TypeVar("F", bound=Callable[..., Any]) + + +@dataclass +class TransformCTX: + """Transform Data class Definition.""" + + source: nn.Module + source_state: dict + target: nn.Module + target_state: dict + + +class _ModelState: + """Helper class for used for to modify state dict of a source model during model conversion.""" + + def __init__(self, state_dict, config=None): + self._state_dict = state_dict + self.config = config + + def state_dict(self): + # pylint: disable=C0115,C0116 + return self._state_dict + + def to(self, dtype): + # pylint: disable=C0115,C0116 + for k, v in self._state_dict.items(): + if v.dtype != dtype: + 
logger.warning(f"Converting {k} from {v.dtype} (source model) to {dtype} (target model)") + self._state_dict[k] = v.to(dtype) + + +@torch.no_grad +def apply_transforms( + source: Union[nn.Module, _ModelState], + target: TargetModuleT, + mapping: Dict[str, str], + transforms: Optional[List[Callable[[TransformCTX], TransformCTX]]] = None, + state_dict_ignored_entries: Optional[List] = None, + cast_dtype: Optional[torch.dtype] = None, +) -> TargetModuleT: + """Transform the state dictionary of a source module to match the structure of a target module's state dictionary. + + This function renames keys according to a provided mapping and modifies values using a list + of transformation functions. Each transformation function typically is decorated + with `io.state_transform`. + + Args: + source (nn.Module): The source module from which parameters and buffers are taken. + target (TargetModuleT): The target module to which parameters and buffers are adapted. + mapping (Dict[str, str]): Key-value pairs where each key from the source state dictionary + is mapped to a corresponding key in the target state dictionary. + transforms (Optional[List[Callable[[TransformCTX], TransformCTX]]]): A list of functions + that modify the `TransformCTX` object. If None, no transformations beyond key renaming + are applied. Defaults to None. + state_dict_ignored_entries: List of entries to ignore in _target.state_dict(). There are cases + where multiple entries in model's state_dict point to one entry in model's named_parameter. + E.g., model has multiple pointers pointing to one shared parameters (`encoder.embed_tokens.weight`, + `decoder.embed_tokens.weight` and `shared.weight` all points to `shared.weight + in T5 Huggingface implementation.). In these cases, ignore redundant entries. + cast_dtype: case the output state dict to a certain precision. 
+ + Returns: + TargetModuleT: The modified target module with its state dictionary adjusted according to + the specified mappings and transformations. + + Raises: + ValueError: If there's a mismatch in shape between corresponding source and target parameters + or buffers. + RuntimeError: If the target state dictionary contains keys that are not present in the source + state dictionary after all transformations. + + Examples: + >>> source_module = nn.Linear(10, 5) + >>> target_module = nn.Linear(10, 5) + >>> mapping = {'weight': 'weights', 'bias': 'biases'} + @io.state_transform( + source_key="weight", + target_key="weights" + ) + def scale_weights(ctx): + ctx.target_state['weights'] = ctx.source_state['weight'] * 2 + return ctx + >>> transformed_target = apply_transforms( + ... source_module, target_module, mapping, [scale_weights] + ... ) + >>> print(transformed_target.state_dict()['weights']) + + See Also: + - `TransformCTX`: For more details on the context object used in transformations. + - `StateDictTransform`: For creating complex transformations. + + Note: + This function is particularly useful when adapting models from different frameworks or + when consolidating models with different architectural changes. + """ + if transforms is None: + transforms = [] + if state_dict_ignored_entries is None: + state_dict_ignored_entries = [] + + # Track dtypes to make sure they weren't modified during conversion. 
+ target_orig_dtypes = extract_dtypes(target.named_parameters()) + + target_state = target.state_dict() + ctx = TransformCTX( + source=source, + source_state=source.state_dict(), + target=target, + target_state=target_state, + ) + + for key, val in mapping.items(): + logger.debug(f"Mapping {key} -> {val}") + ctx = StateDictTransform(key, val)(ctx) + + for transform in transforms: + logger.debug(f"Transforming {transform.source_key} -> {transform.target_key}") + ctx = transform(ctx) + + _params: Dict[str, nn.Parameter] = {} + for name, param in target.named_parameters(): + if name in target_state: + target_param = target_state[name] + if param.data.shape != target_param.shape: + raise ValueError( + f"Shape mismatch for parameter {name}: target shape {param.shape} vs " + f"converted source shape {target_param.shape}" + ) + + _params[name] = nn.Parameter(target_param, requires_grad=param.requires_grad) + target_state.pop(name) + else: + print(f"Unexpected key: {name} not in target model but is in source model.") + + for key, val in _params.items(): + _module, _key = target, key + if "." in key: + for part in key.split(".")[:-1]: + _module = getattr(_module, part) + _key = key.split(".")[-1] + + _module.register_parameter(_key, val) + + _buffers = {} + for name, buffer in target.named_buffers(): + if name in target_state: + if buffer.shape != target_state[name].shape: + raise ValueError(f"Shape mismatch for buffer {name}: {buffer.shape} vs {target_state[name].shape}") + + _buffers[name] = nn.Parameter(target_state[name], requires_grad=False) + target_state.pop(name) + + for key, val in _buffers.items(): + _module, _key = target, key + if "." 
in key: + for part in key.split(".")[:-1]: + _module = getattr(_module, part) + _key = key.split(".")[-1] + + _module.register_buffer(_key, val) + + keys = list(filter(lambda x: x is not None and not x.endswith("_extra_state"), target_state.keys())) + keys = [key for key in keys if key not in state_dict_ignored_entries] + if len(keys) != 0: + raise RuntimeError(f"Additional keys: {keys} in target model but not in source model.") + + if hasattr(target, "tie_weights"): + target.tie_weights() + + meta_tensor_keys = [] + for name, param in target.named_parameters(): + if param.is_meta: + meta_tensor_keys.append(name) + + assert not meta_tensor_keys, ( + f"{meta_tensor_keys}\nThere are meta tensors in the model after conversion." + f"Did you forget to include these parameters in the mapping or transforms in `convert_state`?" + ) + + if cast_dtype: + logger.info(f"Casting model to {cast_dtype}...") + target.to(cast_dtype) + logger.info(f"Casting model to {cast_dtype} complete.") + else: + target_new_dtypes = extract_dtypes(target.named_parameters()) + for key in target_orig_dtypes.keys(): + if key in target_new_dtypes: # For tied weights, these parameters may disappear. + assert target_orig_dtypes[key] == target_new_dtypes[key], ( + f"dtype mismatch for key {key}: {target_orig_dtypes[key]} vs {target_new_dtypes[key]}" + ) + + return target + + +def _default_transform(inp): + return inp + + +class StateDictTransform(Generic[F]): + """A transformation class for state dictionaries. + + Allows for flexible key matching and transformation of values between source and target state dictionaries. + + Attributes: + source_key: A string, tuple of strings, or a dictionary specifying the keys in the source + state dictionary to match. Wildcards (*) are supported. + target_key: A string or tuple of strings specifying the keys in the target state dictionary + to match. Wildcards (*) are supported. + transform: A callable that performs the transformation on matched keys' values. 
class StateDictTransform(Generic[F]):
    """A transformation class for state dictionaries.

    Allows for flexible key matching and transformation of values between source and target state dictionaries.

    Attributes:
        source_key: A string, tuple of strings, or a dictionary specifying the keys in the source
            state dictionary to match. Wildcards (*) are supported.
        target_key: A string or tuple of strings specifying the keys in the target state dictionary
            to match. Wildcards (*) are supported.
        transform: A callable that performs the transformation on matched keys' values.

    Examples:
        >>> def example_transform(ctx, *args):
        ...     return sum(args)
        >>> transform = StateDictTransform(
        ...     source_key="model.layers.*.self_attn.*_proj.weight",
        ...     target_key="decoder.layers.*.self_attention.linear_qkv.weight",
        ...     transform=example_transform
        ... )
    """

    def __init__(
        self,
        source_key: Union[str, Tuple[str, ...], Dict[str, str]],
        target_key: Union[str, Tuple[str, ...]],
        transform: F = _default_transform,
    ):
        """Initialize the StateDictTransform."""
        self.source_key = source_key
        self.target_key = target_key
        self.transform = transform

    def __call__(self, ctx: TransformCTX) -> TransformCTX:
        """Perform the transformation on the given context."""
        source_key = self.source_key
        target_key = self.target_key
        source_dict, target_dict = ctx.source_state, ctx.target_state
        np.set_printoptions(threshold=10)
        # The transform function's named parameters (minus ctx) define how source keys are
        # bound to arguments for tuple/dict source keys.
        fn_params = dict(inspect.signature(self.transform).parameters)
        fn_params.pop("ctx", None)
        matched = False
        if isinstance(source_key, (dict, tuple)):
            # Many-to-one merge: each transform parameter name maps to one source pattern.
            if isinstance(source_key, tuple):
                source_key_dict = {param: source_key[i] for i, param in enumerate(fn_params)}
            else:
                source_key_dict = source_key
            source_matches_dict = {k: _match_keys(list(source_dict.keys()), v) for k, v in source_key_dict.items()}
            target_matches = _match_keys(list(target_dict.keys()), target_key)
            param_names = list(filter(lambda x: x in source_matches_dict, fn_params))
            # Normalize 0-d match arrays to single-element lists so zip below works uniformly.
            source_matches = [
                source_matches_dict[v] if source_matches_dict[v].ndim > 0 else [source_matches_dict[v].item()]
                for v in param_names
            ]
            target_matches = [target_matches if target_matches.ndim > 0 else [target_matches.item()]]
            for layer_names_group in zip(*(source_matches + target_matches)):
                # Wrap in a list if it's a single layer (ie non-expert)
                if isinstance(layer_names_group[0], str):
                    layer_names_group = [[x] for x in layer_names_group]  # noqa: PLW2901
                for layer_names in zip(*layer_names_group):
                    target_dict[layer_names[-1]] = self.call_transform(
                        ctx, **dict(zip(param_names, [source_dict[x] for x in layer_names[:-1]]))
                    )
                    logger.debug(f"Matched (transform)! {layer_names_group=}")
                    matched = True
        else:
            source_keys = list(source_dict.keys())
            target_keys = list(target_dict.keys())

            source_matches = _match_keys(source_keys, source_key)
            if source_matches.size == 1 and source_matches == np.array(None):
                raise ValueError(f"No matches found for source key: {source_key}")

            if isinstance(target_key, str):
                target_matches = _match_keys(target_keys, target_key)
                if target_matches.size == 1 and target_matches == np.array(None):
                    raise ValueError(f"No matches found for target key: {target_key}")
            else:
                if isinstance(target_key, dict):
                    raise ValueError("Target key must be a string or a tuple of strings.")
                # One-to-many split: stack per-target-pattern matches along a new last axis.
                _matches = [_match_keys(target_keys, key) for key in target_key]
                target_matches = np.stack(_matches, axis=-1)

            # Determine if we are dealing with multiple source matches or multiple target matches
            multiple_sources = source_matches.ndim >= target_matches.ndim
            accepts_var_args = any(
                param.kind == param.VAR_POSITIONAL for param in inspect.signature(self.transform).parameters.values()
            )

            if multiple_sources:
                for target_index, target_match in np.ndenumerate(target_matches):
                    try:
                        source_match = source_matches[target_index]
                    except IndexError as e:
                        logger.error(f"Encountered IndexError during transform.\n{source_matches=}\n{target_matches=}")
                        raise e
                    if accepts_var_args:
                        source_values = [source_dict[k] for k in source_match]
                        target_dict[target_match] = self.call_transform(ctx, *source_values)
                    else:
                        _source_match_list = [source_match] if isinstance(source_match, str) else list(source_match)
                        if len(fn_params) != len(_source_match_list):
                            raise ValueError(
                                f"Mismatch between source and target keys: {source_match} vs {target_match}"
                            )

                        kwargs = {param: source_dict[k] for param, k in zip(fn_params, _source_match_list)}
                        target_dict[target_match] = self.call_transform(ctx, **kwargs)
                    logger.debug(f"Matched (multi source)! {target_match=} {source_match=}")
                    matched = True
            else:
                for source_index, source_match in np.ndenumerate(source_matches):
                    target_match = target_matches[source_index]
                    source_values = (
                        [source_dict[source_match]]
                        if np.isscalar(source_match)
                        else [source_dict[k] for k in source_match]
                    )
                    if accepts_var_args:
                        outputs = self.call_transform(ctx, *source_values)
                    else:
                        kwargs = dict(zip(fn_params, source_values))
                        outputs = self.call_transform(ctx, **kwargs)

                    if isinstance(target_match, str):
                        target_dict[target_match] = outputs
                    else:
                        # Splitting: the transform returned one value per target key.
                        for i, t in enumerate(outputs):
                            target_dict[target_match[i]] = t
                    logger.debug(f"Matched (single source)! {target_match=} {source_match=}")
                    matched = True
        if not matched:
            logger.warning(f"No matches found for source key: {source_key=} {target_key=}")
        return ctx

    def call_transform(self, ctx: TransformCTX, *args, **kwargs):
        """Perform transform and check if the given args valid."""
        func_params = inspect.signature(self.transform).parameters
        expected_num_args = len([p for p in func_params if p not in ["self", "ctx"]])
        provided_num_args = len(args) + len(kwargs)
        accepts_var_args = any(param.kind == param.VAR_POSITIONAL for param in func_params.values())

        if not accepts_var_args and provided_num_args != expected_num_args:
            raise ValueError(
                f"Expected {expected_num_args} arguments for the transformation function, but got {provided_num_args}."
            )

        # Pass the context through only if the transform declares a `ctx` parameter.
        if "ctx" in func_params:
            return self.transform(ctx, *args, **kwargs)

        return self.transform(*args, **kwargs)
def _match_keys(keys: List[str], pattern: str) -> np.ndarray:
    """Match ``keys`` against a wildcard ``pattern`` and arrange them in an object ndarray.

    ``*`` matches a single dot-free segment; ``**`` matches across dots. The returned
    array has one axis per wildcard, indexed by the sorted unique capture values
    (numeric captures sort numerically so layer indices come out in order).
    """
    escaped_pattern = ""
    i = 0
    wildcard_positions = []
    while i < len(pattern):
        if pattern[i : i + 2] == "**":
            escaped_pattern += r"(.+)"  # Match any characters including dots
            wildcard_positions.append("**")
            i += 2
        elif pattern[i] == "*":
            escaped_pattern += r"([^.]+)"  # Match any characters except dots
            wildcard_positions.append("*")
            i += 1
        else:
            if pattern[i] == ".":
                escaped_pattern += r"\."  # Escape the dot
            else:
                # NOTE(review): other regex metacharacters are passed through unescaped;
                # patterns are assumed to contain only word chars, dots, and wildcards.
                escaped_pattern += pattern[i]
            i += 1

    regex_pattern = re.compile("^" + escaped_pattern + "$")
    num_wildcards = len(wildcard_positions)
    wildcard_matches = [[] for _ in range(num_wildcards)]

    # First pass: collect the unique capture values seen for each wildcard slot.
    for key in filter(lambda x: x is not None, keys):
        match = regex_pattern.match(key)
        if match:
            for i, group in enumerate(match.groups()):
                if group not in wildcard_matches[i]:
                    wildcard_matches[i].append(group)

    # Sort the wildcard matches to maintain consistent ordering
    for i in range(len(wildcard_matches)):
        wildcard_matches[i].sort(key=lambda x: int(x) if x.isdigit() else x)

    # Determine the shape of the output array based on the unique matches for each wildcard
    shape = [len(matches) for matches in wildcard_matches]

    if len(wildcard_matches) == 0:
        # If there are no wildcard matches, assume it is a single match
        shape = [1]
    # Initialize an empty array with the determined shape
    output_array = np.empty(shape, dtype=object)

    # Populate the array with the keys, now that we have the correct shape and ordering
    for key in filter(lambda x: x is not None, keys):
        match = regex_pattern.match(key)
        if match:
            # Convert match groups to indices based on their position in wildcard_matches
            indices = [wildcard_matches[i].index(group) for i, group in enumerate(match.groups())]
            output_array[tuple(indices)] = key  # Place the key in the array based on the indices

    return output_array
@overload
def state_transform(
    source_key: Union[str, Tuple[str, ...], Dict[str, str]],
    target_key: Union[str, Tuple[str, ...]],
) -> Callable[[F], StateDictTransform[F]]: ...


@overload
def state_transform(
    source_key: Union[str, Tuple[str, ...], Dict[str, str]], target_key: Union[str, Tuple[str, ...]], fn: F
) -> StateDictTransform[F]: ...


def state_transform(
    source_key: Union[str, Tuple[str, ...], Dict[str, str]],
    target_key: Union[str, Tuple[str, ...]],
    fn: Optional[F] = None,
):
    """Create a StateDictTransform instance with specified source and target keys, and a transformation function.

    Usable either as a factory (``fn`` given) or as a decorator (``fn`` omitted).

    Args:
        source_key: A string, tuple of strings, or a dictionary specifying the keys in the source
            state dictionary to match. Wildcards (*) are supported.
        target_key: A string or tuple of strings specifying the keys in the target state dictionary
            to match. Wildcards (*) are supported.
        fn: An optional callable that performs the transformation on matched keys' values. If not
            provided, the decorator can be used to wrap a function definition.

    Returns:
        A StateDictTransform instance if `fn` is provided, otherwise a decorator that
        takes a function and returns a StateDictTransform instance.

    Examples:
        >>> @state_transform(
        ...     source_key="model.layers.*.self_attn.*_proj.weight",
        ...     target_key="decoder.layers.*.self_attention.linear_qkv.weight"
        ... )
        ... def sum_transform(ctx, *args):
        ...     return sum(args)
    """

    def decorate(func) -> StateDictTransform:
        # Bind the key patterns and the transformation function into a single transform object.
        return StateDictTransform(source_key, target_key, func)

    return decorate if fn is None else decorate(fn)
class TransformFns:
    """A collection of common functions used in state dict transformation."""

    @staticmethod
    def split_qkv(ctx: TransformCTX, linear_qkv: torch.Tensor):
        """Split interleave-concatenated qkv to q, k, v.

        Example: export layer linear_qkv to HF {q|k|v}_proj
        """
        target_config = ctx.target.config

        head_num = target_config.num_attention_heads
        num_query_groups = target_config.num_key_value_heads
        heads_per_group = head_num // num_query_groups
        hidden_size = target_config.hidden_size
        head_size = hidden_size // head_num
        # Fused layout interleaves per group: [q_0..q_{h-1}, k, v] x num_query_groups.
        qkv_total_dim = head_num + 2 * num_query_groups

        linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, -1])
        # when converting base model (linear_qkv), hidden size = megatron_config.hidden_size
        # when converting lora (linear_qkv.adapter.linear_out), hidden size = lora_r
        hidden_size = linear_qkv.size(-1)
        q_slice = torch.cat(
            [
                torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
                for i in range(num_query_groups)
            ]
        )
        k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
        v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))

        q_proj = linear_qkv[q_slice].reshape(-1, hidden_size).cpu()
        k_proj = linear_qkv[k_slice].reshape(-1, hidden_size).cpu()
        v_proj = linear_qkv[v_slice].reshape(-1, hidden_size).cpu()

        return q_proj, k_proj, v_proj

    @staticmethod
    def split_qkv_bias(ctx: TransformCTX, qkv_bias: torch.Tensor):
        """Split interleave-concatenated qkv bias to separate q, k, v bias.

        Example: export layer linear_qkv bias to HF {q|k|v}_proj bias
        """
        # NOTE(review): reads ctx.source.config with Megatron attribute names
        # (num_query_groups, kv_channels) while split_qkv reads ctx.target.config
        # with HF names -- confirm the intended direction for this copy.
        megatron_config = ctx.source.config

        head_num = megatron_config.num_attention_heads
        num_query_groups = megatron_config.num_query_groups
        heads_per_group = head_num // num_query_groups
        head_size = megatron_config.kv_channels
        qkv_total_dim = head_num + 2 * num_query_groups

        qkv_bias = qkv_bias.reshape([qkv_total_dim, head_size])
        q_slice = torch.cat(
            [
                torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
                for i in range(num_query_groups)
            ]
        )
        k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
        v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))

        q_bias = qkv_bias[q_slice].reshape(-1).cpu()
        k_bias = qkv_bias[k_slice].reshape(-1).cpu()
        v_bias = qkv_bias[v_slice].reshape(-1).cpu()

        return q_bias, k_bias, v_bias

    @staticmethod
    def merge_qkv_concat(ctx: TransformCTX, qkv: torch.Tensor):
        """Merge naively concatenated q, k, v to interleave-concatenated qkv.

        Example: import HF qkv to layer linear_qkv
        """
        megatron_config = ctx.target.config
        head_num = megatron_config.num_attention_heads
        num_query_groups = megatron_config.num_query_groups
        head_size = megatron_config.kv_channels
        # Undo the naive [Q | K | V] concatenation, then re-interleave per group.
        q, k, v = qkv.split([head_num * head_size, num_query_groups * head_size, num_query_groups * head_size], dim=0)
        return TransformFns.merge_qkv(ctx, q, k, v)

    @staticmethod
    def merge_qkv(ctx: TransformCTX, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """Merge q, k, v to interleave-concatenated qkv.

        Example: import HF {q|k|v}_proj to layer linear_qkv
        """
        target_config = ctx.target.config

        head_num = target_config.num_attention_heads
        num_query_groups = target_config.num_key_value_heads
        heads_per_group = head_num // num_query_groups
        hidden_size = target_config.hidden_size
        head_size = hidden_size // head_num
        old_tensor_shape = q.size()
        new_q_tensor_shape = (head_num, head_size, *old_tensor_shape[1:])
        new_kv_tensor_shape = (num_query_groups, head_size, *old_tensor_shape[1:])

        q = q.view(*new_q_tensor_shape)
        k = k.view(*new_kv_tensor_shape)
        v = v.view(*new_kv_tensor_shape)

        # Interleave per group: heads_per_group query heads, then one K and one V head.
        qkv_weights_l = []
        for i in range(num_query_groups):
            qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :])
            qkv_weights_l.append(k[i : i + 1, :, :])
            qkv_weights_l.append(v[i : i + 1, :, :])
        qkv_weights = torch.cat(qkv_weights_l)
        assert qkv_weights.ndim == 3, qkv_weights.shape
        assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape
        assert qkv_weights.shape[1] == head_size, qkv_weights.shape
        assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape

        qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size])

        return qkv_weights

    @staticmethod
    def merge_qkv_bias_concat(ctx: TransformCTX, qkv_bias: torch.Tensor):
        """Merge naively concatenated q, k, v bias to interleave-concatenated qkv bias.

        Example: import HF qkv bias to layer linear_qkv bias
        """
        megatron_config = ctx.target.config
        head_num = megatron_config.num_attention_heads
        num_query_groups = megatron_config.num_query_groups
        head_size = megatron_config.kv_channels
        qb, kb, vb = qkv_bias.split(
            [head_num * head_size, num_query_groups * head_size, num_query_groups * head_size], dim=0
        )
        return TransformFns.merge_qkv_bias(ctx, qb, kb, vb)

    @staticmethod
    def merge_qkv_bias(ctx: TransformCTX, qb: torch.Tensor, kb: torch.Tensor, vb: torch.Tensor):
        """Merge q, k, v bias to interleave-concatenated qkv bias.

        Example: import HF {q|k|v}_proj bias to layer linear_qkv bias
        """
        megatron_config = ctx.target.config

        head_num = megatron_config.num_attention_heads
        num_query_groups = megatron_config.num_query_groups
        heads_per_group = head_num // num_query_groups
        head_size = megatron_config.kv_channels

        new_q_tensor_shape = (head_num, head_size)
        new_kv_tensor_shape = (num_query_groups, head_size)

        qb = qb.view(*new_q_tensor_shape)
        kb = kb.view(*new_kv_tensor_shape)
        vb = vb.view(*new_kv_tensor_shape)

        # Same per-group interleaving as merge_qkv, built up by repeated concatenation.
        qkv_bias = torch.empty((0, head_size)).type_as(qb)
        for i in range(num_query_groups):
            qkv_bias = torch.cat((qkv_bias, qb[i * heads_per_group : (i + 1) * heads_per_group, :]))
            qkv_bias = torch.cat((qkv_bias, kb[i : i + 1, :]))
            qkv_bias = torch.cat((qkv_bias, vb[i : i + 1, :]))
        qkv_bias = qkv_bias.reshape(
            [
                head_size * (head_num + 2 * num_query_groups),
            ]
        )
        return qkv_bias

    @staticmethod
    def merge_fc1(gate: torch.Tensor, up: torch.Tensor):
        """Merge gate and up proj into concatenated fc1.

        Example: import HF {gate|up}_proj to layer linear_fc1
        """
        return torch.cat((gate, up), dim=0)

    @staticmethod
    def split_fc1(linear_fc1: torch.Tensor):
        """Split concatenated fc1 to gate and up proj.

        Example: export layer linear_fc1 to HF {gate|up}_proj
        """
        gate_proj, up_proj = torch.chunk(linear_fc1, 2, dim=0)
        return gate_proj, up_proj

    @staticmethod
    def duplicate2(param: torch.Tensor):
        """Duplicate the source parameter to two target parameters.

        Example: export Performant LoRA linear_fc1.adapter.linear_in to HF {gate|up}_proj.lora_A
        """
        return param, param

    @staticmethod
    def duplicate3(param: torch.Tensor):
        """Duplicate the source parameter to three target parameters.

        Example: export Performant LoRA linear_qkv.adapter.linear_in to HF {q|k|v}_proj.lora_A
        """
        return param, param, param

    @staticmethod
    def prune_padding(ctx: TransformCTX, embedding: torch.Tensor):
        """Prune the embedding size to vocab size.

        Example: export embedding/output layer to HF with non-padded vocab size
        """
        megatron_config = ctx.target.config
        return embedding[: megatron_config.vocab_size, :]
def extract_dtypes(ckpt):
    """Collect the dtype of every entry in an iterator of (name, value) pairs.

    ``ckpt`` can be ``module.named_parameters()`` or ``module.state_dict().items()``.
    Entries exposing neither ``.dtype`` nor ``.data.dtype`` (e.g. non-tensor metadata)
    are silently skipped.
    """
    result = {}
    for name, value in ckpt:
        if hasattr(value, "dtype"):
            result[name] = value.dtype
            continue
        # ShardedTensor-style wrappers keep the dtype on their populated .data payload.
        payload = getattr(value, "data", None)
        if payload is not None and hasattr(payload, "dtype"):
            result[name] = payload.dtype
    return result
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from pathlib import Path + + +sys.path.append(Path(__file__).parent.parent.as_posix()) diff --git a/bionemo-recipes/recipes/vllm_inference/esm2/tests/test_vllm.py b/bionemo-recipes/recipes/vllm_inference/esm2/tests/test_vllm.py new file mode 100644 index 0000000000..e42dd451e0 --- /dev/null +++ b/bionemo-recipes/recipes/vllm_inference/esm2/tests/test_vllm.py @@ -0,0 +1,151 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Golden-value tests for ESM2 vLLM compatibility. + +Performs a fresh facebook -> TE export, then cross-validates embeddings across +vLLM, HuggingFace (exported checkpoint), and HuggingFace (nvidia Hub reference). + +vLLM's pooling runner returns last-token, L2-normalised embeddings by default, +so the HuggingFace paths replicate that post-processing for comparison. 
+""" + +from pathlib import Path + +import numpy as np +import pytest +import torch +from transformers import AutoModel, AutoTokenizer + + +try: + from vllm import LLM + + _VLLM_AVAILABLE = True +except ImportError: + _VLLM_AVAILABLE = False + +from export import export_hf_checkpoint + + +EXPORT_TAG = "esm2_t6_8M_UR50D" +REFERENCE_MODEL_ID = "nvidia/esm2_t6_8M_UR50D" +ESM2_MODEL_DIR = Path(__file__).resolve().parent.parent + +SEQUENCES = [ + "LKGHAMCLGCLHMLMCGLLAGAMCGLMKLLKCCGKCLMHLMKAMLGLKCACHHHHLLLHACAAKKLCLGAKLAMGLKLLGAHGKGLKMACGHHMLHLHMH", + "CLLCCMHMHAHHCHGHGHKCKCLMMGMALMCAGCCACGMKGGCHCCLLAHCAHAKAGKGKCKLMCKKKHGLHAGLHAMLLCHLGLGCGHHHKKCKKHKCA", +] + + +def _last_token_l2(hidden_state: torch.Tensor) -> np.ndarray: + """Extract last-token hidden state and L2-normalise (matches vLLM pooling defaults).""" + vec = hidden_state[0, -1, :].cpu().float().numpy() + norm = np.linalg.norm(vec) + if norm > 1e-9: + vec = vec / norm + return vec + + +def _hf_embed(model_id: str, sequences: list[str], dtype=torch.float32) -> np.ndarray: + """Run HuggingFace inference and return last-token L2-normalised embeddings.""" + model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to("cuda", dtype=dtype).eval() + tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) + + vecs = [] + with torch.no_grad(): + for seq in sequences: + inputs = tokenizer(seq, return_tensors="pt") + inputs = {k: v.to("cuda") for k, v in inputs.items()} + out = model(**inputs) + vecs.append(_last_token_l2(out.last_hidden_state)) + + del model, tokenizer + torch.cuda.empty_cache() + return np.stack(vecs) + + +def _vllm_embed(model_id: str, sequences: list[str]) -> np.ndarray: + """Run vLLM pooling inference and return embeddings.""" + engine = LLM( + model=model_id, + runner="pooling", + trust_remote_code=True, + dtype="float32", + enforce_eager=True, + max_num_batched_tokens=1026, + ) + outputs = engine.embed(sequences) + + vecs = [] + for output in outputs: + emb = 
def _vllm_embed(model_id: str, sequences: list[str]) -> np.ndarray:
    """Run vLLM pooling inference and return embeddings."""
    engine = LLM(
        model=model_id,
        runner="pooling",
        trust_remote_code=True,
        dtype="float32",
        enforce_eager=True,
        max_num_batched_tokens=1026,
    )
    pooled = engine.embed(sequences)

    vecs = []
    for result in pooled:
        embedding = result.outputs.embedding
        # vLLM may hand back a plain Python list; normalise to ndarray.
        if isinstance(embedding, list):
            embedding = np.array(embedding)
        vecs.append(embedding)

    del engine
    return np.stack(vecs)


# ---- Fixtures ----


@pytest.fixture(scope="session")
def exported_checkpoint(tmp_path_factory):
    """Fresh facebook -> TE export. Session-scoped so the export runs once."""
    export_root = tmp_path_factory.mktemp("vllm_export")
    export_hf_checkpoint(EXPORT_TAG, export_root)
    return str(export_root / EXPORT_TAG)


@pytest.fixture(scope="session")
def vllm_embeddings(exported_checkpoint):
    """Embeddings from the vLLM pooling runner on the exported checkpoint."""
    if not _VLLM_AVAILABLE:
        pytest.skip("vllm not installed")
    return _vllm_embed(exported_checkpoint, SEQUENCES)


@pytest.fixture(scope="session")
def hf_exported_embeddings(exported_checkpoint):
    """Embeddings from HuggingFace on the exported checkpoint."""
    return _hf_embed(exported_checkpoint, SEQUENCES)


@pytest.fixture(scope="session")
def hf_reference_embeddings():
    """Embeddings from HuggingFace on the nvidia Hub model (ground truth)."""
    return _hf_embed(REFERENCE_MODEL_ID, SEQUENCES)


# ---- Tests ----


@pytest.mark.skipif(not _VLLM_AVAILABLE, reason="vllm not installed")
def test_vllm_vs_hf_exported(vllm_embeddings, hf_exported_embeddings):
    """vLLM and native HuggingFace on the same exported checkpoint must match."""
    np.testing.assert_allclose(vllm_embeddings, hf_exported_embeddings, atol=2e-4)


@pytest.mark.skipif(not _VLLM_AVAILABLE, reason="vllm not installed")
def test_vllm_vs_hf_reference(vllm_embeddings, hf_reference_embeddings):
    """vLLM on the exported checkpoint must match HuggingFace on the nvidia Hub model."""
    np.testing.assert_allclose(vllm_embeddings, hf_reference_embeddings, atol=2e-4)
hf_reference_embeddings) diff --git a/ci/scripts/check_copied_files.py b/ci/scripts/check_copied_files.py index 8c952432df..b97ebd4ed3 100755 --- a/ci/scripts/check_copied_files.py +++ b/ci/scripts/check_copied_files.py @@ -34,6 +34,7 @@ "bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py", "bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py", "bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py", + "bionemo-recipes/recipes/vllm_inference/esm2/modeling_esm_te.py", ], "bionemo-recipes/models/esm2/collator.py": [ "bionemo-recipes/models/llama3/collator.py", @@ -46,6 +47,7 @@ "bionemo-recipes/models/amplify/src/amplify/state.py", "bionemo-recipes/models/llama3/state.py", "bionemo-recipes/models/mixtral/state.py", + "bionemo-recipes/recipes/vllm_inference/esm2/state.py", ], "bionemo-recipes/models/llama3/modeling_llama_te.py": [ "bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py", @@ -53,6 +55,15 @@ "bionemo-recipes/models/llama3/nucleotide_fast_tokenizer": [ "bionemo-recipes/recipes/llama3_native_te/tokenizers/nucleotide_fast_tokenizer", ], + "bionemo-recipes/models/esm2/convert.py": [ + "bionemo-recipes/recipes/vllm_inference/esm2/convert.py", + ], + "bionemo-recipes/models/esm2/export.py": [ + "bionemo-recipes/recipes/vllm_inference/esm2/export.py", + ], + "bionemo-recipes/models/esm2/esm_fast_tokenizer": [ + "bionemo-recipes/recipes/vllm_inference/esm2/esm_fast_tokenizer", + ], # Common test library - synced between models "bionemo-recipes/models/esm2/tests/common": [ "bionemo-recipes/models/llama3/tests/common",