Skip to content

Commit d208023

Browse files
authored
[BIONEMO-3530] Fix weight initialization in ESM2 (#1406)
### Description This PR fixes a bug where, when instantiating a model via `from_pretrained`, the layers that are not part of the pretrained checkpoint were not being initialized. ### Type of changes <!-- Mark the relevant option with an [x] --> - [x] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Refactor - [ ] Documentation update - [ ] Other (please describe): ### CI Pipeline Configuration Configure CI behavior by applying the relevant labels. By default, only basic unit tests are run. - [ciflow:skip](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:skip) - Skip all CI tests for this PR - [ciflow:notebooks](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:notebooks) - Run Jupyter notebooks execution tests for bionemo2 - [ciflow:slow](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:slow) - Run slow single GPU integration tests marked as @pytest.mark.slow for bionemo2 - [ciflow:all](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:all) - Run all tests (unit tests, slow tests, and notebooks) for bionemo2. This label can be used to enforce running tests for all bionemo2. - [ciflow:all-recipes](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/main/contributing/contributing.md#ciflow:all-recipes) - Run tests for all recipes (under bionemo-recipes). This label can be used to enforce running tests for all recipes. Unit tests marked as `@pytest.mark.multi_gpu` or `@pytest.mark.distributed` are not run in the PR pipeline. For more details, see [CONTRIBUTING](CONTRIBUTING.md) > [!NOTE] > By default, only basic unit tests are run. Add appropriate labels to enable additional test coverage. 
#### Authorizing CI Runs We use [copy-pr-bot](https://docs.gha-runners.nvidia.com/apps/copy-pr-bot/#automation) to manage authorization of CI runs on NVIDIA's compute resources. - If a pull request is opened by a trusted user and contains only trusted changes, the pull request's code will automatically be copied to a pull-request/ prefixed branch in the source repository (e.g. pull-request/123) - If a pull request is opened by an untrusted user or contains untrusted changes, an NVIDIA org member must leave an `/ok to test` comment on the pull request to trigger CI. This will need to be done for each new commit. ### Pre-submit Checklist <!--- Ensure all items are completed before submitting --> - [x] I have tested these changes locally - [x] I have updated the documentation accordingly - [x] I have added/updated tests as needed - [x] All existing tests pass successfully Signed-off-by: Bruno Alvisio <balvisio@nvidia.com>
1 parent f318fe2 commit d208023

File tree

5 files changed

+53
-5
lines changed

5 files changed

+53
-5
lines changed

bionemo-recipes/models/esm2/src/esm/modeling_esm_te.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,11 @@ def init_empty_weights(self):
284284
# Meta-device init seems to break weight tying, so we re-tie the weights here.
285285
self.tie_weights()
286286

287+
@classmethod
288+
def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
289+
"""Override the default get_init_context method to allow for fp8 model initialization."""
290+
return []
291+
287292

288293
class NVEsmModel(NVEsmPreTrainedModel):
289294
"""The ESM Encoder-only protein language model.

bionemo-recipes/models/esm2/tests/test_meta_device_init.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
from transformer_engine.pytorch.tensor import QuantizedTensor
3636
from transformers import AutoConfig, set_seed
3737

38-
from esm.modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM
38+
from esm.modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM, NVEsmForTokenClassification
3939

4040

4141
requires_multi_gpu = pytest.mark.skipif(
@@ -44,6 +44,10 @@
4444
)
4545

4646

47+
def msg(x):
48+
return f"Mismatch in module {name}: {x}"
49+
50+
4751
def verify_model_parameters_initialized_correctly(
4852
model: NVEsmForMaskedLM, atol=1e-3, rtol=1e-4, should_be_fp8: bool = False
4953
):
@@ -53,10 +57,6 @@ def verify_model_parameters_initialized_correctly(
5357
assert str(parameter.device).startswith("cuda"), f"Parameter {name} is not on the cuda device"
5458

5559
for name, module in model.named_modules():
56-
57-
def msg(x):
58-
return f"Mismatch in module {name}: {x}"
59-
6060
if isinstance(module, torch.nn.Embedding):
6161
torch.testing.assert_close(module.weight.mean().item(), 0.0, atol=atol, rtol=rtol, msg=msg)
6262
torch.testing.assert_close(
@@ -118,6 +118,22 @@ def msg(x):
118118
torch.testing.assert_close(module.inv_freq, expected_inv_freq, msg=msg)
119119

120120

121+
def verify_pretrained_model_sanity(model: NVEsmForTokenClassification, atol=1e-3, rtol=1e-4):
122+
for name, p in model.named_parameters():
123+
assert p.numel() > 0, f"{name} is empty"
124+
assert torch.isfinite(p).all(), f"{name} has NaN/Inf"
125+
126+
max_abs = p.abs().max().item()
127+
assert max_abs < 1e3, f"{name} extreme values: {max_abs}"
128+
129+
if name == "classifier.weight":
130+
torch.testing.assert_close(p.mean().item(), 0.0, atol=atol, rtol=rtol, msg=msg)
131+
torch.testing.assert_close(p.std().item(), model.config.initializer_range, atol=atol, rtol=rtol, msg=msg)
132+
133+
if name == "classifier.bias":
134+
torch.testing.assert_close(p, torch.zeros_like(p), msg=msg)
135+
136+
121137
def test_cuda_init():
122138
config = NVEsmConfig(**AutoConfig.from_pretrained("facebook/esm2_t6_8M_UR50D").to_dict())
123139

@@ -170,6 +186,18 @@ def test_meta_fp8_init(fp8_recipe):
170186
verify_model_parameters_initialized_correctly(model, should_be_fp8=True)
171187

172188

189+
def test_model_for_token_classification_init():
190+
config = NVEsmConfig(**AutoConfig.from_pretrained("nvidia/esm2_t6_8M_UR50D", trust_remote_code=True).to_dict())
191+
192+
set_seed(42)
193+
model = NVEsmForTokenClassification.from_pretrained(
194+
"nvidia/esm2_t6_8M_UR50D", config=config, dtype=torch.bfloat16, trust_remote_code=True
195+
)
196+
model.to("cuda")
197+
198+
verify_pretrained_model_sanity(model)
199+
200+
173201
@pytest.mark.parametrize("num_gpus", [1, pytest.param(2, marks=requires_multi_gpu)])
174202
def test_meta_device_init_after_fully_shard(num_gpus: int):
175203
cmd = [

bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,11 @@ def init_empty_weights(self):
284284
# Meta-device init seems to break weight tying, so we re-tie the weights here.
285285
self.tie_weights()
286286

287+
@classmethod
288+
def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
289+
"""Override the default get_init_context method to allow for fp8 model initialization."""
290+
return []
291+
287292

288293
class NVEsmModel(NVEsmPreTrainedModel):
289294
"""The ESM Encoder-only protein language model.

bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,11 @@ def init_empty_weights(self):
284284
# Meta-device init seems to break weight tying, so we re-tie the weights here.
285285
self.tie_weights()
286286

287+
@classmethod
288+
def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
289+
"""Override the default get_init_context method to allow for fp8 model initialization."""
290+
return []
291+
287292

288293
class NVEsmModel(NVEsmPreTrainedModel):
289294
"""The ESM Encoder-only protein language model.

bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,11 @@ def init_empty_weights(self):
284284
# Meta-device init seems to break weight tying, so we re-tie the weights here.
285285
self.tie_weights()
286286

287+
@classmethod
288+
def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
289+
"""Override the default get_init_context method to allow for fp8 model initialization."""
290+
return []
291+
287292

288293
class NVEsmModel(NVEsmPreTrainedModel):
289294
"""The ESM Encoder-only protein language model.

0 commit comments

Comments
 (0)