-
Notifications
You must be signed in to change notification settings - Fork 586
Expand file tree
/
Copy pathtest_hooked_encoder_decoder.py
More file actions
424 lines (302 loc) · 16.3 KB
/
test_hooked_encoder_decoder.py
File metadata and controls
424 lines (302 loc) · 16.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
import pytest
import torch
from jaxtyping import Float
from torch.testing import assert_close
from transformers import AutoTokenizer, T5ForConditionalGeneration
from transformer_lens import HookedEncoderDecoder
# Skip entire module in coverage tests due to test pollution issues
# NOTE(review): `pytest.mark.skip` is unconditional, so every test in this module is
# skipped in all runs, not only in coverage runs — confirm that this is still intended.
pytestmark = pytest.mark.skip(reason="Temporarily skipped due to CI test pollution issues")
# Name of the pretrained checkpoint loaded by the model and tokenizer fixtures below.
MODEL_NAME = "t5-small"
@pytest.fixture(scope="module")
def our_model():
    """Module-scoped `HookedEncoderDecoder` loaded from `MODEL_NAME` on the CPU."""
    hooked_model = HookedEncoderDecoder.from_pretrained(MODEL_NAME, device="cpu")
    return hooked_model
@pytest.fixture(scope="module")
def huggingface_model():
    """Module-scoped Hugging Face reference model for `MODEL_NAME`, switched to eval mode."""
    reference_model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
    return reference_model.eval()
@pytest.fixture(scope="module")
def tokenizer():
    """Module-scoped Hugging Face tokenizer loaded for `MODEL_NAME`."""
    hf_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    return hf_tokenizer
@pytest.fixture
def hello_world_tokens(tokenizer):
    """Input ids of the prompt "Hello, world!" as a PyTorch tensor."""
    encoding = tokenizer("Hello, world!", return_tensors="pt")
    return encoding["input_ids"]
@pytest.fixture
def decoder_input_ids(tokenizer):
    """A 1x1 long tensor holding only the tokenizer's pad token id."""
    start_row = [tokenizer.pad_token_id]
    return torch.LongTensor([start_row])
def test_full_model(our_model, huggingface_model, tokenizer, decoder_input_ids):
    """Compare full-model logits with Hugging Face on a padded batch of two prompts."""
    prompts = ["Hello, world!", "this is another sequence of tokens"]
    batch = tokenizer(prompts, return_tensors="pt", padding=True)

    # Repeat the single decoder start row once per prompt.
    decoder_rows = [decoder_input_ids[0] for _ in prompts]
    decoder_ids = torch.stack(decoder_rows, dim=0)

    tokens = batch["input_ids"]
    mask = batch["attention_mask"]

    hf_output = huggingface_model(
        input_ids=tokens,
        attention_mask=mask,
        decoder_input_ids=decoder_ids,
    )
    hf_logits = hf_output.logits

    our_logits = our_model(
        tokens,
        decoder_input=decoder_ids,
        one_zero_attention_mask=mask,
    )

    assert_close(hf_logits, our_logits, rtol=1.3e-6, atol=4e-5)
def test_encoder(our_model, huggingface_model, hello_world_tokens):
    """Run our encoder stack layer by layer and compare with the Hugging Face encoder output."""
    resid = our_model.embed(hello_world_tokens)

    # The bias is computed once by layer 0 and handed to every encoder layer.
    seq_len = hello_world_tokens.shape[1]
    position_bias = our_model.encoder[0].attn.compute_relative_attention_bias(seq_len, seq_len)

    for block in our_model.encoder:
        resid = block(resid_pre=resid, position_bias=position_bias)
    our_hidden = our_model.encoder_final_ln(resid)

    hf_hidden = huggingface_model.encoder(hello_world_tokens).last_hidden_state

    assert_close(our_hidden, hf_hidden, rtol=1.3e-6, atol=4e-5)
def test_decoder(our_model, huggingface_model, hello_world_tokens, decoder_input_ids):
    """Run our decoder stack on Hugging Face encoder states and compare with the HF decoder."""
    # Both decoders consume the same encoder output, taken from the reference model.
    encoder_states = huggingface_model.encoder(hello_world_tokens)[0]

    resid = our_model.embed(decoder_input_ids)
    target_len = decoder_input_ids.shape[1]
    position_bias = our_model.decoder[0].attn.compute_relative_attention_bias(
        target_len, target_len
    )

    for block in our_model.decoder:
        resid = block(resid, encoder_hidden_states=encoder_states, position_bias=position_bias)
    our_hidden = our_model.decoder_final_ln(resid)

    hf_outputs = huggingface_model.decoder(
        decoder_input_ids, encoder_hidden_states=encoder_states
    )
    hf_hidden = hf_outputs[0]

    assert_close(our_hidden, hf_hidden, rtol=1.3e-6, atol=4e-5)
def test_embed_one_sentence(our_model, huggingface_model, hello_world_tokens):
    """Token embeddings of a single prompt must match the Hugging Face embedding table."""
    # Index 0 on the HF side and squeeze(0) on ours both drop the leading batch dimension.
    hf_rows = huggingface_model.encoder.embed_tokens(hello_world_tokens)[0]
    our_rows = our_model.embed(hello_world_tokens).squeeze(0)

    assert_close(hf_rows, our_rows)
def test_relative_attention_bias(our_model, huggingface_model, hello_world_tokens):
    """Check encoder layer-0 self-attention: bias settings, bias table, computed bias, output."""
    hf_attn = huggingface_model.encoder.block[0].layer[0].SelfAttention
    our_attn = our_model.encoder[0].attn

    # Both implementations must flag this layer as carrying a relative attention bias.
    assert hf_attn.has_relative_attention_bias
    assert our_attn.has_relative_attention_bias

    # Bucketing hyper-parameters and the bias weights have to agree.
    assert our_attn.relative_attention_num_buckets == hf_attn.relative_attention_num_buckets
    assert our_attn.relative_attention_max_distance == hf_attn.relative_attention_max_distance
    assert_close(our_attn.rel_pos_bias.weight, hf_attn.relative_attention_bias.weight)

    # The bias computed for a square (seq_len x seq_len) attention pattern must match.
    seq_len = hello_world_tokens.shape[1]
    our_bias = our_attn.compute_relative_attention_bias(seq_len, seq_len)
    hf_bias = hf_attn.compute_bias(seq_len, seq_len)
    assert_close(our_bias, hf_bias, rtol=1e-5, atol=1e-5)

    # Finally compare the attention outputs on the raw token embeddings.
    token_embeds = huggingface_model.encoder.embed_tokens(hello_world_tokens)
    positions = torch.arange(seq_len)
    hf_out = hf_attn(token_embeds, cache_position=positions)[0]
    our_out = our_attn(token_embeds, token_embeds, token_embeds, position_bias=our_bias)
    assert_close(our_out, hf_out, rtol=7.4e-4, atol=1e-5)
def test_relative_attention_layer(our_model, huggingface_model, hello_world_tokens):
    """Compare encoder layer-0 attention sub-layer (norm + attention + residual) with HF."""
    our_block = our_model.encoder[0]
    hf_sublayer = huggingface_model.encoder.block[0].layer[0]

    seq_len = hello_world_tokens.shape[1]
    token_embeds = huggingface_model.encoder.embed_tokens(hello_world_tokens)

    # Ours: normalise, attend with the relative position bias, then add the residual back.
    position_bias = our_block.attn.compute_relative_attention_bias(seq_len, seq_len)
    normed = our_block.ln1(token_embeds)
    attn_out = our_block.attn(normed, normed, normed, position_bias=position_bias)
    our_result = token_embeds + attn_out

    # The Hugging Face sub-layer is called directly on the un-normalised embeddings.
    positions = torch.arange(seq_len)
    hf_result = hf_sublayer(token_embeds, cache_position=positions)[0]

    assert_close(our_result, hf_result, rtol=1.3e-6, atol=4e-5)
def test_attention(our_model, huggingface_model, hello_world_tokens):
    """Compare encoder layer-1 self-attention outputs on the raw HF token embeddings."""
    token_embeds = huggingface_model.encoder.embed_tokens(hello_world_tokens)

    # Our attention receives the same tensor for all three positional inputs.
    our_out = our_model.encoder[1].attn(token_embeds, token_embeds, token_embeds)

    hf_attn = huggingface_model.encoder.block[1].layer[0].SelfAttention
    positions = torch.arange(hello_world_tokens.shape[1])
    hf_out = hf_attn(token_embeds, cache_position=positions)[0]

    assert_close(our_out, hf_out, rtol=5e-4, atol=1e-5)
def test_decoder_attention(our_model, huggingface_model, hello_world_tokens):
    """Compare decoder layer-1 self-attention outputs on the raw HF token embeddings."""
    token_embeds = huggingface_model.decoder.embed_tokens(hello_world_tokens)

    # Our attention receives the same tensor for all three positional inputs.
    our_out = our_model.decoder[1].attn(token_embeds, token_embeds, token_embeds)

    hf_attn = huggingface_model.decoder.block[1].layer[0].SelfAttention
    positions = torch.arange(hello_world_tokens.shape[1])
    hf_out = hf_attn(token_embeds, cache_position=positions)[0]

    assert_close(our_out, hf_out, rtol=5e-4, atol=1e-5)
def test_attention_layer(our_model, huggingface_model, hello_world_tokens):
    """Compare encoder layer-1 attention sub-layer (norm + attention + residual) with HF."""
    our_block = our_model.encoder[1]
    hf_sublayer = huggingface_model.encoder.block[1].layer[0]

    token_embeds = huggingface_model.encoder.embed_tokens(hello_world_tokens)

    # Ours: layer norm first, attention on the normed input, residual added afterwards.
    normed = our_block.ln1(token_embeds)
    our_result = our_block.attn(normed, normed, normed) + token_embeds

    positions = torch.arange(hello_world_tokens.shape[1])
    hf_result = hf_sublayer(token_embeds, cache_position=positions)[0]

    assert_close(our_result, hf_result, rtol=2e-4, atol=1e-5)
def test_decoder_attention_layer(our_model, huggingface_model, hello_world_tokens):
    """Compare decoder layer-1 self-attention sub-layer (norm + attention + residual) with HF."""
    our_block = our_model.decoder[1]
    hf_sublayer = huggingface_model.decoder.block[1].layer[0]

    token_embeds = huggingface_model.decoder.embed_tokens(hello_world_tokens)

    # Ours: layer norm first, attention on the normed input, residual added afterwards.
    normed = our_block.ln1(token_embeds)
    our_result = our_block.attn(normed, normed, normed) + token_embeds

    positions = torch.arange(hello_world_tokens.shape[1])
    hf_result = hf_sublayer(token_embeds, cache_position=positions)[0]

    assert_close(our_result, hf_result, rtol=3e-4, atol=4e-5)
def test_cross_attention(our_model, huggingface_model, hello_world_tokens, decoder_input_ids):
    """Compare decoder layer-0 cross-attention outputs against Hugging Face.

    The decoder token embeddings are the query-side input; the Hugging Face
    encoder output is used as the key/value source for both implementations.
    """
    encoder_hidden = huggingface_model.encoder(hello_world_tokens).last_hidden_state
    decoder_hidden = huggingface_model.decoder.embed_tokens(decoder_input_ids)

    huggingface_cross_attn = huggingface_model.decoder.block[0].layer[1].EncDecAttention
    our_cross_attn = our_model.decoder[0].cross_attn

    our_cross_attn_out = our_cross_attn(decoder_hidden, encoder_hidden, encoder_hidden)

    # cache_position is given the integer positions of the decoder (query) tokens, built
    # the same way as in the self-attention tests, instead of the float encoder states.
    cache_position = torch.arange(decoder_hidden.shape[1])
    huggingface_cross_attn_out = huggingface_cross_attn(
        decoder_hidden, key_value_states=encoder_hidden, cache_position=cache_position
    )[0]

    assert_close(our_cross_attn_out, huggingface_cross_attn_out, rtol=2e-3, atol=1e-4)
def test_cross_attention_layer(our_model, huggingface_model, hello_world_tokens, decoder_input_ids):
    """Compare decoder layer-0 cross-attention sub-layer (norm + attention + residual) with HF.

    The Hugging Face encoder output is the key/value source for both implementations.
    """
    encoder_hidden = huggingface_model.encoder(hello_world_tokens).last_hidden_state
    decoder_hidden = huggingface_model.decoder.embed_tokens(decoder_input_ids)

    hf_layer = huggingface_model.decoder.block[0].layer[1]
    our_layer = our_model.decoder[0]

    # assert ln weights are the same
    assert_close(hf_layer.layer_norm.weight, our_layer.ln2.w)

    # Ours: normalise the decoder input, cross-attend to the encoder, add the residual.
    our_cross_attn_out = (
        our_layer.cross_attn(our_layer.ln2(decoder_hidden), encoder_hidden, encoder_hidden)
        + decoder_hidden
    )

    # cache_position is given the integer positions of the decoder (query) tokens, built
    # the same way as in the self-attention tests, instead of the float encoder states.
    cache_position = torch.arange(decoder_hidden.shape[1])
    huggingface_cross_attn_out = hf_layer(
        decoder_hidden, key_value_states=encoder_hidden, cache_position=cache_position
    )[0]

    assert_close(our_cross_attn_out, huggingface_cross_attn_out, rtol=2e-4, atol=1e-5)
def test_encoder_block(our_model, huggingface_model, hello_world_tokens):
    """Compare the whole encoder block 1 against Hugging Face on raw token embeddings."""
    token_embeds = huggingface_model.encoder.embed_tokens(hello_world_tokens)

    positions = torch.arange(hello_world_tokens.shape[1])
    hf_block = huggingface_model.encoder.block[1]
    hf_result = hf_block(token_embeds, cache_position=positions)[0]

    our_result = our_model.encoder[1](token_embeds)

    assert_close(our_result, hf_result, rtol=2e-4, atol=2e-5)
def test_decoder_block(our_model, huggingface_model, hello_world_tokens, decoder_input_ids):
    """Compare the whole decoder block 1 against Hugging Face.

    The input to block 1 is produced by running Hugging Face decoder block 0 on the
    decoder token embeddings; the Hugging Face encoder output serves as the
    cross-attention source for both implementations.
    """
    huggingface_embed = huggingface_model.decoder.embed_tokens
    huggingface_block = huggingface_model.decoder.block[1]
    our_block = our_model.decoder[1]

    encoder_hidden = huggingface_model.encoder(hello_world_tokens)[0]

    # Integer positions of the decoder tokens, shared by both Hugging Face block calls.
    input_len = decoder_input_ids.shape[1]
    cache_position = torch.arange(input_len)

    decoder_hidden = huggingface_model.decoder.block[0](
        huggingface_embed(decoder_input_ids), cache_position=cache_position
    )[0]

    our_out = our_block(decoder_hidden, encoder_hidden_states=encoder_hidden)

    # Block 1 gets the same decoder positions as block 0 rather than the float
    # encoder hidden states.
    hf_out = huggingface_block(
        decoder_hidden, encoder_hidden_states=encoder_hidden, cache_position=cache_position
    )[0]

    assert_close(hf_out, our_out, rtol=2e-4, atol=2e-5)
def test_layernorm(our_model, huggingface_model, hello_world_tokens):
    """The first layer norm of encoder layer 0 must match Hugging Face on token embeddings."""
    token_embeds = huggingface_model.encoder.embed_tokens(hello_world_tokens)

    our_normed = our_model.encoder[0].ln1(token_embeds)
    hf_normed = huggingface_model.encoder.block[0].layer[0].layer_norm(token_embeds)

    assert_close(our_normed, hf_normed)
def test_unembed(our_model, huggingface_model, hello_world_tokens):
    """Our unembedding must match the Hugging Face `lm_head` on the same hidden states."""
    # Hidden states come from the Hugging Face decoder run directly on the prompt tokens.
    hidden = huggingface_model.decoder(hello_world_tokens).last_hidden_state

    our_logits = our_model.unembed(hidden)
    hf_logits = huggingface_model.lm_head(hidden)

    assert_close(our_logits, hf_logits, rtol=1.3e-3, atol=1e-5)
def test_run_with_cache(our_model, hello_world_tokens, decoder_input_ids):
    """Check that `run_with_cache` records a sample of hooks with the expected shapes.

    Only an arbitrary subset of cache keys is inspected. Every key whose shape is
    checked is first asserted to be present, so a missing hook shows up as a failed
    membership assertion rather than as a lookup error.
    """
    _, cache = our_model.run_with_cache(hello_world_tokens, decoder_input=decoder_input_ids)

    # Sequence lengths are read from the inputs instead of being hard-coded.
    seq_len = hello_world_tokens.shape[1]
    generated_len = decoder_input_ids.shape[1]

    assert "hook_embed" in cache
    assert cache["hook_embed"].shape == (1, seq_len, 512)

    assert "encoder.1.attn.hook_v" in cache
    assert cache["encoder.1.attn.hook_v"].shape == (1, seq_len, 8, 64)

    assert "encoder.3.attn.hook_attn_scores" in cache
    assert cache["encoder.3.attn.hook_attn_scores"].shape == (1, 8, seq_len, seq_len)

    assert "decoder.0.cross_attn.hook_k" in cache
    # The original only asserted hook_k before indexing hook_attn_scores.
    assert "decoder.0.cross_attn.hook_attn_scores" in cache
    assert cache["decoder.0.cross_attn.hook_attn_scores"].shape == (
        1,
        8,
        generated_len,
        seq_len,
    )

    assert "decoder.3.hook_resid_post" in cache
    assert cache["decoder.3.hook_resid_post"].shape == (1, generated_len, 512)
def test_from_pretrained_revision():
    """
    Check that the from_pretrained parameter `revision` (= git version) works
    """
    # An existing revision must load without raising.
    _ = HookedEncoderDecoder.from_pretrained(MODEL_NAME, revision="main")

    # A non-existent revision must fail. `pytest.raises(Exception)` replaces the bare
    # `except:` so that KeyboardInterrupt / SystemExit are no longer counted as a pass.
    with pytest.raises(Exception):
        _ = HookedEncoderDecoder.from_pretrained(MODEL_NAME, revision="inexistent_branch_name")
def test_predictions(our_model, huggingface_model, tokenizer, decoder_input_ids):
    """Greedy (argmax) token predictions must decode to the same strings as Hugging Face."""

    def decode_argmax(logits: Float[torch.Tensor, "batch pos d_vocab"]):
        # Take the first batch element and pick the highest-scoring token per position.
        best_tokens = logits[0].argmax(dim=-1)
        return tokenizer.batch_decode(best_tokens)

    encoding = tokenizer("My name is Wolfgang and I live in Berlin", return_tensors="pt")
    input_ids = encoding["input_ids"]

    our_logits = our_model(input_ids, decoder_input=decoder_input_ids)
    our_prediction = decode_argmax(our_logits)

    hf_logits = huggingface_model(input_ids, decoder_input_ids=decoder_input_ids).logits
    hf_prediction = decode_argmax(hf_logits)

    assert our_prediction == hf_prediction
def test_predictions_string_input(our_model, huggingface_model, tokenizer):
    """Calling our model with a raw string must reproduce the Hugging Face logits."""
    prompt = "translate English to German: Hello, do you like bananas?"
    encodings = tokenizer(prompt, return_tensors="pt")
    input_ids = encodings.input_ids

    # Hugging Face needs explicit decoder inputs: one pad token per batch row.
    n_rows, _ = input_ids.shape
    decoder_start = torch.full((n_rows, 1), tokenizer.pad_token_id)

    our_logits = our_model(prompt)
    hf_logits = huggingface_model(
        input_ids=input_ids,
        attention_mask=encodings.attention_mask,
        decoder_input_ids=decoder_start,
    ).logits

    assert_close(our_logits, hf_logits, rtol=1e-5, atol=1e-5)
def test_predictions_string_list_input(our_model, huggingface_model, tokenizer):
    """Calling our model with a list of strings must reproduce the Hugging Face logits."""
    prompt = [
        "translate English to German: Hello, do you like bananas?",
        "translate English to French: Hello, do you like bananas?",
        "translate English to Spanish: Hello, do you like bananas?",
    ]
    # padding=True (as in test_full_model) keeps batching valid even if the prompts
    # tokenize to different lengths; it changes nothing when the lengths are equal.
    encodings = tokenizer(prompt, return_tensors="pt", padding=True)
    tokens = encodings.input_ids

    # Hugging Face needs explicit decoder inputs: one pad token per batch row.
    batch_size = tokens.shape[0]
    decoder_input_ids = torch.full((batch_size, 1), tokenizer.pad_token_id)

    our_model_logits = our_model(prompt)
    huggingface_model_logits = huggingface_model(
        input_ids=tokens,
        attention_mask=encodings.attention_mask,
        decoder_input_ids=decoder_input_ids,
    ).logits

    assert_close(our_model_logits, huggingface_model_logits, rtol=1e-5, atol=1e-5)
def test_generate(our_model, huggingface_model, tokenizer):
    """Greedy generation from a string prompt must match Hugging Face (case-insensitively)."""
    prompt = "translate English to German: Hello, do you like bananas?"
    encodings = tokenizer(prompt, return_tensors="pt")

    # Both generations are given the same new-token budget so that neither side can be
    # truncated at a different length than the other.
    max_new_tokens = 20

    our_generation = our_model.generate(prompt, do_sample=False, max_new_tokens=max_new_tokens)

    huggingface_generated_tokens = huggingface_model.generate(
        input_ids=encodings.input_ids,
        attention_mask=encodings.attention_mask,
        do_sample=False,
        max_new_tokens=max_new_tokens,
    )[0]
    huggingface_generation = tokenizer.decode(
        huggingface_generated_tokens, skip_special_tokens=True
    )

    assert our_generation.lower() == huggingface_generation.lower()
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires a CUDA device")
def test_cuda(hello_world_tokens, decoder_input_ids):
    """Smoke test: a forward pass with CUDA decoder inputs; nothing about the output is asserted."""
    cuda_model = HookedEncoderDecoder.from_pretrained(MODEL_NAME)
    decoder_ids_on_gpu = decoder_input_ids.cuda()
    cuda_model(hello_world_tokens, decoder_input=decoder_ids_on_gpu)