-
Notifications
You must be signed in to change notification settings - Fork 127
Expand file tree
/
Copy pathtest_modeling_esm_te.py
More file actions
276 lines (226 loc) · 10.2 KB
/
test_modeling_esm_te.py
File metadata and controls
276 lines (226 loc) · 10.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for ESM2 model using the common test library.
This file provides comprehensive tests for the ESM2 model including:
- Meta device initialization tests
- Golden value tests against HuggingFace reference models
- Conversion tests (HF ↔ TE)
- FP8 tests
- Model-specific tests
Most tests are inherited from the common test library to reduce duplication.
"""
from typing import Callable, Dict, List, Literal, Type
from unittest.mock import MagicMock
import torch
from torch import nn
from transformers import (
AutoTokenizer,
DataCollatorForLanguageModeling,
PretrainedConfig,
PreTrainedModel,
PreTrainedTokenizer,
)
from transformers.models.esm.modeling_esm import EsmForMaskedLM
from collator import DataCollatorWithFlattening
from convert import (
_pack_qkv_bias,
_pack_qkv_weight,
_pad_bias,
_pad_weights,
convert_esm_hf_to_te,
convert_esm_te_to_hf,
mapping,
)
from modeling_esm_te import NVEsmConfig, NVEsmForMaskedLM
from tests.common import BaseModelTest, TestTolerances
class TestESM2Model(BaseModelTest):
    """Model tester for ESM2.

    This class provides ESM2-specific configuration for the common test suite
    (model/config classes, upstream HuggingFace checkpoint, test inputs,
    HF <-> TE converters and tolerances) plus ESM2-specific conversion checks.
    """

    def get_model_class(self) -> Type[PreTrainedModel]:
        """Return the ESM2 TE model class."""
        return NVEsmForMaskedLM

    def get_config_class(self) -> Type[PretrainedConfig]:
        """Return the ESM2 config class."""
        return NVEsmConfig

    def get_upstream_model_id(self) -> str:
        """Return the upstream HuggingFace model ID."""
        return "facebook/esm2_t6_8M_UR50D"

    def get_upstream_model_revision(self) -> str:
        """Return the specific (pinned) revision for the upstream model."""
        return "c731040f"

    def get_upstream_model_class(self) -> Type[PreTrainedModel]:
        """Return the upstream HuggingFace model class."""
        return EsmForMaskedLM

    def get_layer_path(self, model: PreTrainedModel) -> List[nn.Module]:
        """Return the list of transformer layers of a TE ESM2 model (``model.model.encoder.layers``)."""
        return list(model.model.encoder.layers)  # type: ignore

    def get_reference_model_no_weights(self) -> PreTrainedModel:
        """Return the base-class reference model with its contact head removed.

        For checkpoint conversion tests to pass, we need to remove the unused contact head.
        """
        model = super().get_reference_model_no_weights()
        del model.esm.contact_head
        return model

    def get_test_input_data(
        self,
        format: Literal["bshd", "thd"] = "bshd",
        pad_to_multiple_of: int | None = None,
    ) -> Dict[str, torch.Tensor]:
        """Prepare a masked-LM test batch from a fixed set of real protein sequences.

        Args:
            format: ``"bshd"`` collates with ``DataCollatorForLanguageModeling`` directly;
                ``"thd"`` wraps that collator in ``DataCollatorWithFlattening``.
            pad_to_multiple_of: In ``"bshd"`` format, forwarded to the MLM collator as
                ``pad_to_multiple_of``; in ``"thd"`` format, forwarded to the flattening
                collator as ``pad_sequences_to_be_divisible_by`` instead.

        Returns:
            The collated batch, with every tensor value moved to CUDA and any
            non-tensor value passed through unchanged.
        """
        tokenizer = self.get_tokenizer()

        # Use real protein sequences
        test_proteins = [
            "MLSATEKLSDYISSLFASVSIINSISTEDLFFLKLTCQTFSKDSEEYKAAYRILRGVQRGKVQIIEEALVS",
            "MFVFFAGTLVNQDTLNFRDQLNINVVGTVRGIAQDASKYLEYAIDSV",
            "MAATGSLILSDEEQAELIALAVRIVLACAGGSQNKELAAQLGVIETTVGEWRRRFAQNRVEGLRDEARPGAPSDDQ",
            "MSAVLSAVASDDWTAFAKLVHPYVHWTADGITTRGRTRVMARLSGHDGVKPASSYELRDGQVYRWTS",
            "MSDPAAEPPADTSGIAWRKSSYSGPNGNCVELAQISGDHVGIRNSRDLHGSVLTCTRAEFAALLCDIKAGRFDSLIL",
        ]

        # Tokenize
        tokenized = [tokenizer(p, truncation=True, max_length=1024) for p in test_proteins]

        # Use data collator for MLM. A fixed seed is passed, presumably so the random
        # masking is reproducible across runs (golden-value comparisons depend on it).
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            mlm_probability=0.15,
            # In THD format the padding is applied by the flattening collator instead.
            pad_to_multiple_of=pad_to_multiple_of if format == "bshd" else None,
            seed=42,
        )

        if format == "thd":
            data_collator = DataCollatorWithFlattening(
                collator=data_collator,
                pad_sequences_to_be_divisible_by=pad_to_multiple_of,
            )

        batch = data_collator(tokenized)

        # Move to device
        return {k: v.to("cuda") if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

    def get_hf_to_te_converter(self) -> Callable:
        """Return the HF to TE conversion function."""
        return convert_esm_hf_to_te

    def get_te_to_hf_converter(self) -> Callable:
        """Return the TE to HF conversion function."""
        return convert_esm_te_to_hf

    def get_tolerances(self) -> TestTolerances:
        """Return ESM2-specific test tolerances."""
        return TestTolerances(
            golden_value_loss_atol=2e-2,
            golden_value_loss_rtol=1e-2,
            golden_value_logits_atol=2.0,  # Higher tolerance needed after transformers PR#40370
            golden_value_logits_rtol=1e-4,
            cp_loss_atol=0.1,
            cp_loss_rtol=0.05,
        )

    def get_tokenizer(self) -> PreTrainedTokenizer:
        """Return the ESM2 tokenizer.

        NOTE(review): ``"esm_fast_tokenizer"`` is a bare relative name, presumably a
        tokenizer directory resolved relative to the test working directory — confirm.
        """
        return AutoTokenizer.from_pretrained("esm_fast_tokenizer")

    # ==================== ESM2-Specific Tests ====================

    @staticmethod
    def _mock_transform_ctx(**config_attrs) -> MagicMock:
        """Return a fresh mock transform context whose ``target.config`` has ``config_attrs`` set."""
        ctx = MagicMock()
        for name, value in config_attrs.items():
            setattr(ctx.target.config, name, value)
        return ctx

    def test_convert_state_dict_explicit_check(self):
        """Test detailed state dict conversion and mapping.

        Every TE state-dict entry (except ``*_extra_state`` and ``*inv_freq`` entries)
        must be matched against the HF model by exactly one of: the 1:1 ``mapping``
        table, the packed-QKV weight/bias transforms, or the vocab-padding transforms.
        Finally checks that the word-embedding and LM-head weights share storage
        (are tied) in both the HF and TE models.
        """
        # get_test_input_data already returns the batch on CUDA.
        input_data = self.get_test_input_data()
        model_hf = self.get_reference_model()
        model_te = self.get_converted_te_model()
        model_hf.to("cuda")
        model_te.to("cuda")

        with torch.no_grad():
            outputs = model_te(**input_data)
        # `assert outputs.loss` would pass on NaN/inf and fail on an exact 0.0 loss.
        assert outputs.loss is not None
        assert torch.isfinite(outputs.loss).all(), f"Non-finite loss from converted TE model: {outputs.loss}"

        # Build each state dict once; state_dict() reconstructs a full dict on every call.
        te_sd = model_te.state_dict()
        hf_sd = model_hf.state_dict()
        num_layers = model_hf.config.num_hidden_layers

        # TE keys still waiting to be matched; each check below removes the key it
        # verified, and set.remove raises KeyError if a key is matched twice.
        te_state_dict_keys = {k for k in te_sd if not k.endswith("_extra_state") and not k.endswith("inv_freq")}

        # Check standard mapping ("*" in a key stands for the layer index).
        for hf_pattern, te_pattern in mapping.items():
            if "*" not in hf_pattern:
                torch.testing.assert_close(
                    te_sd[te_pattern],
                    hf_sd[hf_pattern],
                    msg=lambda x, k=hf_pattern: f"{k} is not close: {x}",
                )
                te_state_dict_keys.remove(te_pattern)
                continue
            for i in range(num_layers):
                te_key = te_pattern.replace("*", str(i))
                torch.testing.assert_close(
                    te_sd[te_key],
                    hf_sd[hf_pattern.replace("*", str(i))],
                    msg=lambda x, k=hf_pattern, i=i: f"{k} {i} is not close: {x}",
                )
                te_state_dict_keys.remove(te_key)

        # Check packed QKV weights, then packed QKV biases: TE stores one fused
        # layernorm_qkv tensor where HF has separate query/key/value tensors.
        for param, packer in (("weight", _pack_qkv_weight), ("bias", _pack_qkv_bias)):
            for i in range(num_layers):
                te_key = f"model.encoder.layers.{i}.self_attention.layernorm_qkv.{param}"
                hf_q, hf_k, hf_v = (
                    hf_sd[f"esm.encoder.layer.{i}.attention.self.{proj}.{param}"]
                    for proj in ("query", "key", "value")
                )
                ctx = self._mock_transform_ctx(num_attention_heads=model_hf.config.num_attention_heads)
                packed = packer.transform(ctx, hf_q, hf_k, hf_v)
                torch.testing.assert_close(
                    packed,
                    te_sd[te_key],
                    msg=lambda x, k=te_key: f"{k} is not close: {x}",
                )
                te_state_dict_keys.remove(te_key)

        # Check padded embeddings and LM head (TE pads the vocab dimension to padded_vocab_size).
        pad_ctx = self._mock_transform_ctx(padded_vocab_size=model_te.config.padded_vocab_size)
        torch.testing.assert_close(
            _pad_weights(pad_ctx, hf_sd["esm.embeddings.word_embeddings.weight"]),
            te_sd["model.embeddings.word_embeddings.weight"],
        )
        torch.testing.assert_close(
            _pad_weights(pad_ctx, hf_sd["lm_head.decoder.weight"]),
            te_sd["lm_head.decoder.weight"],
        )
        torch.testing.assert_close(
            _pad_bias.transform(pad_ctx, hf_sd["lm_head.bias"]),
            te_sd["lm_head.decoder.bias"],
        )
        for padded_key in (
            "model.embeddings.word_embeddings.weight",
            "lm_head.decoder.weight",
            "lm_head.decoder.bias",
        ):
            te_state_dict_keys.remove(padded_key)

        assert not te_state_dict_keys, f"TE state-dict keys not covered by the conversion: {sorted(te_state_dict_keys)}"

        # Check that the tied weights are the same (shared storage, not just equal values).
        assert hf_sd["esm.embeddings.word_embeddings.weight"].data_ptr() == hf_sd["lm_head.decoder.weight"].data_ptr()
        assert te_sd["model.embeddings.word_embeddings.weight"].data_ptr() == te_sd["lm_head.decoder.weight"].data_ptr()

    def create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1):
        """These are unused for non-autoregressive models; always returns ``None``."""
        pass