Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,33 @@ repos:
- id: ruff
# 1. Attempt to automatically fix any lint issues.
args: ["--fix"]
# Exclude check_copied_files destinations; they are verbatim copies
# of source files and must not be reformatted independently.
exclude: |
(?x)^(
bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv\.py|
bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv\.py|
bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv\.py|
bionemo-recipes/recipes/vllm_inference/esm2/(modeling_esm_te|convert|export|state)\.py|
bionemo-recipes/recipes/(esm2_native_te|llama3_native_te|esm2_peft_te)/collator\.py|
bionemo-recipes/recipes/llama3_native_te/modeling_llama_te\.py|
bionemo-recipes/models/(llama3|mixtral)/collator\.py|
bionemo-recipes/models/(amplify/src/amplify|llama3|mixtral)/state\.py|
bionemo-recipes/models/(llama3|mixtral)/tests/common/
)$
- id: ruff-format
exclude: |
(?x)^(
bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv\.py|
bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv\.py|
bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv\.py|
bionemo-recipes/recipes/vllm_inference/esm2/(modeling_esm_te|convert|export|state)\.py|
bionemo-recipes/recipes/(esm2_native_te|llama3_native_te|esm2_peft_te)/collator\.py|
bionemo-recipes/recipes/llama3_native_te/modeling_llama_te\.py|
bionemo-recipes/models/(llama3|mixtral)/collator\.py|
bionemo-recipes/models/(amplify/src/amplify|llama3|mixtral)/state\.py|
bionemo-recipes/models/(llama3|mixtral)/tests/common/
)$
- repo: https://github.com/executablebooks/mdformat
rev: 0.7.22 # Use the latest stable version
hooks:
Expand Down
2 changes: 2 additions & 0 deletions bionemo-recipes/models/esm2/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ Training recipes are available in the `bionemo-recipes/recipes/` directory:
loop.
- **[esm2_accelerate_te](../../recipes/esm2_accelerate_te/)** - Trains the model using HuggingFace
[Accelerate](https://huggingface.co/docs/accelerate/index).
- **[vllm_inference/esm2](../../recipes/vllm_inference/esm2/)** - Demonstrates inference with
[vLLM](https://github.com/vllm-project/vllm).

## Converting Between Model Formats

Expand Down
36 changes: 18 additions & 18 deletions bionemo-recipes/models/esm2/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,18 @@


mapping = {
"esm.encoder.layer.*.attention.output.dense.weight": "esm.encoder.layers.*.self_attention.proj.weight",
"esm.encoder.layer.*.attention.output.dense.bias": "esm.encoder.layers.*.self_attention.proj.bias",
"esm.encoder.layer.*.attention.LayerNorm.weight": "esm.encoder.layers.*.self_attention.layernorm_qkv.layer_norm_weight",
"esm.encoder.layer.*.attention.LayerNorm.bias": "esm.encoder.layers.*.self_attention.layernorm_qkv.layer_norm_bias",
"esm.encoder.layer.*.intermediate.dense.weight": "esm.encoder.layers.*.layernorm_mlp.fc1_weight",
"esm.encoder.layer.*.intermediate.dense.bias": "esm.encoder.layers.*.layernorm_mlp.fc1_bias",
"esm.encoder.layer.*.output.dense.weight": "esm.encoder.layers.*.layernorm_mlp.fc2_weight",
"esm.encoder.layer.*.output.dense.bias": "esm.encoder.layers.*.layernorm_mlp.fc2_bias",
"esm.encoder.layer.*.LayerNorm.weight": "esm.encoder.layers.*.layernorm_mlp.layer_norm_weight",
"esm.encoder.layer.*.LayerNorm.bias": "esm.encoder.layers.*.layernorm_mlp.layer_norm_bias",
"esm.encoder.emb_layer_norm_after.weight": "esm.encoder.emb_layer_norm_after.weight",
"esm.encoder.emb_layer_norm_after.bias": "esm.encoder.emb_layer_norm_after.bias",
"esm.encoder.layer.*.attention.output.dense.weight": "model.encoder.layers.*.self_attention.proj.weight",
"esm.encoder.layer.*.attention.output.dense.bias": "model.encoder.layers.*.self_attention.proj.bias",
"esm.encoder.layer.*.attention.LayerNorm.weight": "model.encoder.layers.*.self_attention.layernorm_qkv.layer_norm_weight",
"esm.encoder.layer.*.attention.LayerNorm.bias": "model.encoder.layers.*.self_attention.layernorm_qkv.layer_norm_bias",
"esm.encoder.layer.*.intermediate.dense.weight": "model.encoder.layers.*.layernorm_mlp.fc1_weight",
"esm.encoder.layer.*.intermediate.dense.bias": "model.encoder.layers.*.layernorm_mlp.fc1_bias",
"esm.encoder.layer.*.output.dense.weight": "model.encoder.layers.*.layernorm_mlp.fc2_weight",
"esm.encoder.layer.*.output.dense.bias": "model.encoder.layers.*.layernorm_mlp.fc2_bias",
"esm.encoder.layer.*.LayerNorm.weight": "model.encoder.layers.*.layernorm_mlp.layer_norm_weight",
"esm.encoder.layer.*.LayerNorm.bias": "model.encoder.layers.*.layernorm_mlp.layer_norm_bias",
"esm.encoder.emb_layer_norm_after.weight": "model.encoder.emb_layer_norm_after.weight",
"esm.encoder.emb_layer_norm_after.bias": "model.encoder.emb_layer_norm_after.bias",
"lm_head.dense.weight": "lm_head.dense.weight",
"lm_head.dense.bias": "lm_head.dense.bias",
"lm_head.layer_norm.weight": "lm_head.decoder.layer_norm_weight",
Expand Down Expand Up @@ -135,7 +135,7 @@ def convert_esm_te_to_hf(model_te: nn.Module, **config_kwargs) -> nn.Module:
"esm.encoder.layer.*.attention.self.key.weight",
"esm.encoder.layer.*.attention.self.value.weight",
),
target_key="esm.encoder.layers.*.self_attention.layernorm_qkv.weight",
target_key="model.encoder.layers.*.self_attention.layernorm_qkv.weight",
)
def _pack_qkv_weight(ctx: state.TransformCTX, query, key, value):
"""Pack separate Q, K, V weight tensors into a single interleaved QKV weight tensor."""
Expand All @@ -157,7 +157,7 @@ def _pack_qkv_weight(ctx: state.TransformCTX, query, key, value):
"esm.encoder.layer.*.attention.self.key.bias",
"esm.encoder.layer.*.attention.self.value.bias",
),
target_key="esm.encoder.layers.*.self_attention.layernorm_qkv.bias",
target_key="model.encoder.layers.*.self_attention.layernorm_qkv.bias",
)
def _pack_qkv_bias(ctx: state.TransformCTX, query, key, value):
"""Pack separate Q, K, V bias tensors into a single interleaved QKV bias tensor."""
Expand All @@ -174,7 +174,7 @@ def _pack_qkv_bias(ctx: state.TransformCTX, query, key, value):


@state.state_transform(
source_key="esm.encoder.layers.*.self_attention.layernorm_qkv.weight",
source_key="model.encoder.layers.*.self_attention.layernorm_qkv.weight",
target_key=(
"esm.encoder.layer.*.attention.self.query.weight",
"esm.encoder.layer.*.attention.self.key.weight",
Expand Down Expand Up @@ -203,7 +203,7 @@ def _unpack_qkv_weight(ctx: state.TransformCTX, qkv_weight):


@state.state_transform(
source_key="esm.encoder.layers.*.self_attention.layernorm_qkv.bias",
source_key="model.encoder.layers.*.self_attention.layernorm_qkv.bias",
target_key=(
"esm.encoder.layer.*.attention.self.query.bias",
"esm.encoder.layer.*.attention.self.key.bias",
Expand Down Expand Up @@ -248,7 +248,7 @@ def _pad_weights(ctx: state.TransformCTX, source_embed):

_pad_embeddings = state.state_transform(
source_key="esm.embeddings.word_embeddings.weight",
target_key="esm.embeddings.word_embeddings.weight",
target_key="model.embeddings.word_embeddings.weight",
)(_pad_weights)

_pad_decoder_weights = state.state_transform(
Expand All @@ -257,7 +257,7 @@ def _pad_weights(ctx: state.TransformCTX, source_embed):
)(_pad_weights)

_unpad_embeddings = state.state_transform(
source_key="esm.embeddings.word_embeddings.weight",
source_key="model.embeddings.word_embeddings.weight",
target_key="esm.embeddings.word_embeddings.weight",
)(_unpad_weights)

Expand Down
6 changes: 5 additions & 1 deletion bionemo-recipes/models/esm2/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,11 @@ def export_hf_checkpoint(tag: str, export_path: Path):
model_hf_masked_lm = AutoModelForMaskedLM.from_pretrained(f"facebook/{tag}")
model_hf = AutoModel.from_pretrained(f"facebook/{tag}")
model_hf_masked_lm.esm.pooler = model_hf.pooler
model_te = convert_esm_hf_to_te(model_hf_masked_lm)

# Export without vocab padding so the checkpoint stores embeddings at the real
# vocab_size. This avoids shape-mismatch errors in vLLM's VocabParallelEmbedding,
# which expects vocab_size-shaped weights.
model_te = convert_esm_hf_to_te(model_hf_masked_lm, padded_vocab_size=None)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[not blocking]: Okay made some changes to convert_esm_hf_to_te

model_te.save_pretrained(export_path / tag)

tokenizer = AutoTokenizer.from_pretrained("esm_fast_tokenizer") # Use our PreTrainedTokenizerFast implementation.
Expand Down
47 changes: 30 additions & 17 deletions bionemo-recipes/models/esm2/modeling_esm_te.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(
max_seq_length: Optional[int] = None,
padded_vocab_size: Optional[int] = 64,
attn_mask_type: str = "padding",
add_pooling_layer: bool = False,
layer_precision: list[str | None] | None = None,
**kwargs,
):
Expand Down Expand Up @@ -103,6 +104,9 @@ def __init__(
padded_vocab_size: The padded vocabulary size to support FP8. If ``None``, falls back
to vocab_size (no padding). Must be greater than or equal to vocab_size.
attn_mask_type: The type of attention mask to use.
add_pooling_layer: Whether the base model should include a pooling layer.
Defaults to ``False`` because exported checkpoints do not contain pooler
weights. Set to ``True`` only if you have a checkpoint with pooler weights.
layer_precision: Per-layer quantization precision, a list of length ``num_hidden_layers``
where each element is ``"fp8"``, ``"fp4"``, or ``None`` (BF16 fallback). ``None``
(the default) means no quantization is configured.
Expand All @@ -117,6 +121,7 @@ def __init__(
self.micro_batch_size = micro_batch_size
self.max_seq_length = max_seq_length
self.attn_mask_type = attn_mask_type
self.add_pooling_layer = add_pooling_layer
self.layer_precision = layer_precision

# Set padded_vocab_size with default fallback to vocab_size
Expand Down Expand Up @@ -289,7 +294,7 @@ class NVEsmPreTrainedModel(EsmPreTrainedModel):
"""An abstract class to handle weights initialization and pretrained model loading."""

config_class = NVEsmConfig
base_model_prefix = "esm"
base_model_prefix = "model"
supports_gradient_checkpointing = False
accepts_loss_kwargs = False
_no_split_modules = (
Expand All @@ -305,11 +310,11 @@ def init_empty_weights(self):
if hasattr(module, "reset_parameters"):
module.reset_parameters()

# The esm.embeddings layer is the only non-TE layer in this model we need to deal with. We use
# The embeddings layer is the only non-TE layer in this model we need to deal with. We use
# `model._init_weights` rather than `reset_parameters` to ensure we honor the original config standard
# deviation.
self.esm.embeddings.word_embeddings.to_empty(device="cuda")
self.esm.embeddings.apply(self._init_weights)
# deviation. self.base_model resolves to self.model for wrapper classes or self for NVEsmModel.
self.base_model.embeddings.word_embeddings.to_empty(device="cuda")
self.base_model.embeddings.apply(self._init_weights)

# Meta-device init seems to break weight tying, so we re-tie the weights here.
self.tie_weights()
Expand All @@ -334,14 +339,16 @@ def _init_weights(self, module):
super()._init_weights(module)

def state_dict(self, *args, **kwargs):
"""Override state_dict to filter out TransformerEngine's _extra_state keys.
"""Override state_dict to filter out non-loadable keys.

TransformerEngine layers add _extra_state attributes that are not compatible with HuggingFace v5 model loading.
These are filtered out to ensure checkpoints can be loaded with from_pretrained().
Filters out:
- ``_extra_state`` keys: TransformerEngine-specific, not loadable by HuggingFace v5.
- ``.inv_freq`` buffers: Computed at init time by RotaryPositionEmbedding, not needed
in the checkpoint and not loadable by vLLM's AutoWeightsLoader (which only iterates
over ``named_parameters``, not ``named_buffers``).
"""
state_dict = super().state_dict(*args, **kwargs)
# Filter out _extra_state keys which are TransformerEngine-specific and not loadable
return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state")}
return {k: v for k, v in state_dict.items() if not k.endswith("_extra_state") and not k.endswith(".inv_freq")}


class NVEsmModel(NVEsmPreTrainedModel):
Expand All @@ -350,16 +357,20 @@ class NVEsmModel(NVEsmPreTrainedModel):
This model uses NVIDIA's TransformerEngine to optimize attention layer training and inference.
"""

def __init__(self, config: NVEsmConfig, add_pooling_layer: bool = True):
def __init__(self, config: NVEsmConfig, add_pooling_layer: Optional[bool] = None):
"""Initialize a NVEsmModel.

Args:
config (NVEsmConfig): The configuration of the model.
add_pooling_layer (bool): Whether to add a pooling layer.
add_pooling_layer (Optional[bool]): Whether to add a pooling layer. If ``None``,
reads ``config.add_pooling_layer``, falling back to ``True`` when the config lacks that attribute.
"""
super().__init__(config)
self.config = config

if add_pooling_layer is None:
add_pooling_layer = getattr(config, "add_pooling_layer", True)

# Ensure pad_token_id is set properly, defaulting to 0 if not specified
if not hasattr(config, "pad_token_id") or config.pad_token_id is None:
config.pad_token_id = 0
Expand Down Expand Up @@ -449,7 +460,9 @@ def forward(
class NVEsmForMaskedLM(NVEsmPreTrainedModel):
"""NVEsmForMaskedLM is a TransformerEngine-optimized ESM model for masked language modeling."""

_tied_weights_keys: ClassVar[dict[str, str]] = {"lm_head.decoder.weight": "esm.embeddings.word_embeddings.weight"}
_tied_weights_keys: ClassVar[dict[str, str]] = {
"lm_head.decoder.weight": "model.embeddings.word_embeddings.weight"
}
_do_not_quantize = ("lm_head.dense", "lm_head.decoder") # Flag for testing that these layers are not quantized.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you're deleting _do_not_quantize? we need that


def __init__(self, config: NVEsmConfig):
Expand All @@ -466,7 +479,7 @@ def __init__(self, config: NVEsmConfig):
"bi-directional self-attention."
)

self.esm = NVEsmModel(config, add_pooling_layer=False)
self.model = NVEsmModel(config, add_pooling_layer=False)
self.lm_head = NVEsmLMHead(config)

self.post_init()
Expand Down Expand Up @@ -501,7 +514,7 @@ def forward(
Returns:
MaskedLMOutput: The output of the model.
"""
outputs = self.esm(
outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
Expand Down Expand Up @@ -719,7 +732,7 @@ def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels

self.esm = NVEsmModel(config, add_pooling_layer=False)
self.model = NVEsmModel(config, add_pooling_layer=False)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = transformer_engine.pytorch.Linear(
config.hidden_size,
Expand All @@ -745,7 +758,7 @@ def forward(
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
"""
outputs = self.esm(
outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
Expand Down
10 changes: 5 additions & 5 deletions bionemo-recipes/models/esm2/tests/test_cp_bshd.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,8 @@ def test_context_parallel_equivalence_2process():

# Sample gradients from a few layers for comparison
sample_layers = [
model.esm.encoder.layers[0].self_attention.core_attention,
model.esm.encoder.layers[0].self_attention.layernorm_qkv,
model.model.encoder.layers[0].self_attention.core_attention,
model.model.encoder.layers[0].self_attention.layernorm_qkv,
]

# Now grab the gradients from the sample layers
Expand Down Expand Up @@ -262,7 +262,7 @@ def test_context_parallel_equivalence_2process():
cp_world_size = torch.distributed.get_world_size(group=cp_group)

# Set up context parallelism for each layer
for i, transformer_layer in enumerate(model.module.esm.encoder.layers):
for i, transformer_layer in enumerate(model.module.model.encoder.layers):
transformer_layer.set_context_parallel_group(
cp_group, torch.distributed.get_process_group_ranks(device_mesh["cp"].get_group()), torch.cuda.Stream()
)
Expand Down Expand Up @@ -347,8 +347,8 @@ def test_context_parallel_equivalence_2process():
# Capture gradients from the same layers in the CP model
# Note: DDP wraps the model with 'module.' prefix
sample_layers_cp = [
model.module.esm.encoder.layers[0].self_attention.core_attention,
model.module.esm.encoder.layers[0].self_attention.layernorm_qkv,
model.module.model.encoder.layers[0].self_attention.core_attention,
model.module.model.encoder.layers[0].self_attention.layernorm_qkv,
]

gradients_cp = {}
Expand Down
10 changes: 5 additions & 5 deletions bionemo-recipes/models/esm2/tests/test_cp_thd.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,8 @@ def test_context_parallel_equivalence_2process():

# Sample gradients from a few layers for comparison
sample_layers = [
model.esm.encoder.layers[0].self_attention.core_attention,
model.esm.encoder.layers[0].self_attention.layernorm_qkv,
model.model.encoder.layers[0].self_attention.core_attention,
model.model.encoder.layers[0].self_attention.layernorm_qkv,
]

# Now grab the gradients from the sample layers
Expand Down Expand Up @@ -253,7 +253,7 @@ def test_context_parallel_equivalence_2process():
cp_world_size = torch.distributed.get_world_size(group=cp_group)

# Set up context parallelism for each layer
for i, transformer_layer in enumerate(model.module.esm.encoder.layers):
for i, transformer_layer in enumerate(model.module.model.encoder.layers):
transformer_layer.set_context_parallel_group(
cp_group, torch.distributed.get_process_group_ranks(device_mesh["cp"].get_group()), torch.cuda.Stream()
)
Expand Down Expand Up @@ -344,8 +344,8 @@ def test_context_parallel_equivalence_2process():
# Capture gradients from the same layers in the CP model
# Note: DDP wraps the model with 'module.' prefix
sample_layers_cp = [
model.module.esm.encoder.layers[0].self_attention.core_attention,
model.module.esm.encoder.layers[0].self_attention.layernorm_qkv,
model.module.model.encoder.layers[0].self_attention.core_attention,
model.module.model.encoder.layers[0].self_attention.layernorm_qkv,
]

gradients_cp = {}
Expand Down
4 changes: 2 additions & 2 deletions bionemo-recipes/models/esm2/tests/test_distributed_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def is_main_process(self) -> bool:
model = NVEsmForMaskedLM(config)

if args.strategy is Strategy.FSDP2:
for layer in model.esm.encoder.layers:
for layer in model.model.encoder.layers:
fully_shard(layer, mesh=device_mesh["dp"])
fully_shard(model, mesh=device_mesh["dp"])
model.to(device)
Expand Down Expand Up @@ -199,7 +199,7 @@ def is_main_process(self) -> bool:
)

# Attach FP8 recipes to the encoder (layer precision is already on config).
encoder = model.module.esm.encoder if args.strategy is Strategy.DDP else model.esm.encoder
encoder = model.module.model.encoder if args.strategy is Strategy.DDP else model.model.encoder
encoder.set_recipes(fp8_recipe=fp8_recipe, fp4_recipe=None)

model.train()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def run_forward_backward(use_te: bool, strategy: Strategy, input_data: dict, dis
revision="c731040f",
)
model = NVEsmForMaskedLM(config)
transformer_layers = model.esm.encoder.layers
transformer_layers = model.model.encoder.layers
else:
model = AutoModelForMaskedLM.from_pretrained(
"facebook/esm2_t6_8M_UR50D",
Expand Down
Loading
Loading