Commit 23896c3

Update model builders (#2282)
1 parent 3cceb86 commit 23896c3

14 files changed (+570 -512 lines)
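
The recurring change in the Gemma builders below is structural: instead of constructing a single TransformerSelfAttentionLayer and passing `layers=layer, num_layers=num_layers` to TransformerDecoder, each builder now builds its layers in an explicit loop and passes the resulting nn.ModuleList. The following is a minimal, self-contained PyTorch sketch of that pattern using toy stand-in modules, not torchtune's actual layer classes, purely to illustrate the before/after shape.

import torch.nn as nn

embed_dim, num_layers = 16, 4

# Old shape (illustrative): one template layer plus a layer count,
# with replication left to the decoder.
layer = nn.Linear(embed_dim, embed_dim)  # stand-in for a transformer layer

# New shape (what this commit adopts): construct every layer explicitly
# and hand the decoder the full stack as an nn.ModuleList.
layers = nn.ModuleList()
for _ in range(num_layers):
    layers.append(nn.Linear(embed_dim, embed_dim))

assert len(layers) == num_layers  # each entry is an independently constructed module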

tests/recipes/dev/test_generate_v2.py  (+1 -3)

@@ -55,9 +55,7 @@ def test_llama2_generate_results(self, caplog, monkeypatch, tmpdir):
         # this is gibberish b/c the model is random weights, but it's
         # the expected value for what we currently have in V2
         # this test should catch any changes to the generate recipe that affect output
-        expected_output = (
-            "Country maior Connection Kohćutsójcustomulas Sometimes Security"
-        )
+        expected_output = "Pietroместkap щotimes rivers cache НиtringindexPathNAME"
 
         logs = caplog.text
         assert expected_output in logs
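
The golden string here is intentionally gibberish: the recipe generates from randomly initialized weights, so the test only pins down determinism for a fixed seed. Building the decoder layers one by one (as the builder changes below do) consumes the seeded RNG differently during initialization, which plausibly explains why the expected text changes. A small, self-contained PyTorch sketch of that effect, not the torchtune recipe itself:

import torch
import torch.nn as nn

torch.manual_seed(0)
_ = nn.Linear(8, 8)                                    # one template layer
after_single = nn.Linear(8, 8).weight.clone()          # a module built afterwards

torch.manual_seed(0)
_ = nn.ModuleList(nn.Linear(8, 8) for _ in range(4))   # per-layer construction
after_stack = nn.Linear(8, 8).weight.clone()

# Same seed, but the extra constructions advanced the RNG, so anything built
# later (embeddings, output projections, ...) ends up with different weights.
print(torch.equal(after_single, after_stack))  # False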

tests/torchtune/models/t5/test_t5_encoder.py  (+16 -16)

@@ -51,24 +51,24 @@ def test_forward(self, model, inputs):
         expected = torch.tensor(
             [
                 [
-                    [0.3670, 0.2938],
-                    [0.3692, 0.2921],
-                    [0.3611, 0.2984],
-                    [0.4207, 0.2437],
-                    [0.3447, 0.3106],
-                    [0.3383, 0.3150],
-                    [0.3727, 0.2892],
-                    [0.3996, 0.2653],
+                    [0.1940, 0.5625],
+                    [0.1893, 0.5681],
+                    [0.2020, 0.5522],
+                    [0.2547, 0.4681],
+                    [0.1769, 0.5822],
+                    [0.2737, 0.4281],
+                    [0.2828, 0.4066],
+                    [0.2841, 0.4033],
                 ],
                 [
-                    [0.3855, 0.2783],
-                    [0.2627, 0.3581],
-                    [0.3601, 0.2992],
-                    [0.3473, 0.3087],
-                    [0.3549, 0.3032],
-                    [0.2871, 0.3459],
-                    [0.2753, 0.3520],
-                    [0.2285, 0.3728],
+                    [0.1796, 0.5792],
+                    [0.2020, 0.5523],
+                    [0.2209, 0.5258],
+                    [0.2802, 0.4128],
+                    [0.2923, 0.3817],
+                    [0.2677, 0.4414],
+                    [0.2458, 0.4847],
+                    [0.1923, 0.5645],
                 ],
             ]
         )

torchtune/models/gemma/_component_builders.py  (+60 -53)

@@ -76,33 +76,37 @@ def gemma(
         TransformerDecoder: Instantiation of gemma model.
     """
     rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
-    self_att = MultiHeadAttention(
-        embed_dim=embed_dim,
-        num_heads=num_heads,
-        num_kv_heads=num_kv_heads,
-        head_dim=head_dim,
-        q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
-        k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
-        v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
-        output_proj=nn.Linear(num_heads * head_dim, embed_dim, bias=False),
-        pos_embeddings=rope,
-        kv_cache=None,
-        max_seq_len=max_seq_len,
-        attn_dropout=attn_dropout,
-    )
-    mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim)
-    layer = TransformerSelfAttentionLayer(
-        attn=self_att,
-        mlp=mlp,
-        sa_norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
-        mlp_norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
-    )
+
+    layers = nn.ModuleList()
+    for _ in range(num_layers):
+        self_att = MultiHeadAttention(
+            embed_dim=embed_dim,
+            num_heads=num_heads,
+            num_kv_heads=num_kv_heads,
+            head_dim=head_dim,
+            q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
+            k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
+            v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
+            output_proj=nn.Linear(num_heads * head_dim, embed_dim, bias=False),
+            pos_embeddings=rope,
+            kv_cache=None,
+            max_seq_len=max_seq_len,
+            attn_dropout=attn_dropout,
+        )
+        mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim)
+        layer = TransformerSelfAttentionLayer(
+            attn=self_att,
+            mlp=mlp,
+            sa_norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
+            mlp_norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
+        )
+        layers.append(layer)
+
     tok_embeddings = GemmaNormEmbeddings(vocab_size, embed_dim)
     output_proj = TiedLinear(tok_embeddings)
     model = TransformerDecoder(
         tok_embeddings=tok_embeddings,
-        layers=layer,
-        num_layers=num_layers,
+        layers=layers,
         max_seq_len=max_seq_len,
         num_heads=num_heads,
         output=output_proj,

@@ -186,47 +190,50 @@ def lora_gemma(
         TransformerDecoder: Instantiation of Gemma model with LoRA applied to
             a subset of the attention projections in each layer.
     """
-    self_attn = lora_gemma_self_attention(
-        lora_modules=lora_attn_modules,
-        embed_dim=embed_dim,
-        head_dim=head_dim,
-        num_heads=num_heads,
-        num_kv_heads=num_kv_heads,
-        max_seq_len=max_seq_len,
-        attn_dropout=attn_dropout,
-        rope_base=rope_base,
-        lora_rank=lora_rank,
-        lora_alpha=lora_alpha,
-        lora_dropout=lora_dropout,
-        use_dora=use_dora,
-        quantize_base=quantize_base,
-    )
-
-    if apply_lora_to_mlp:
-        mlp = lora_gemma_mlp(
-            dim=embed_dim,
-            hidden_dim=intermediate_dim,
+    layers = nn.ModuleList()
+    for _ in range(num_layers):
+        self_attn = lora_gemma_self_attention(
+            lora_modules=lora_attn_modules,
+            embed_dim=embed_dim,
+            head_dim=head_dim,
+            num_heads=num_heads,
+            num_kv_heads=num_kv_heads,
+            max_seq_len=max_seq_len,
+            attn_dropout=attn_dropout,
+            rope_base=rope_base,
             lora_rank=lora_rank,
             lora_alpha=lora_alpha,
             lora_dropout=lora_dropout,
             use_dora=use_dora,
             quantize_base=quantize_base,
         )
-    else:
-        mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim, quantize_base=quantize_base)
 
-    layer = TransformerSelfAttentionLayer(
-        attn=self_attn,
-        mlp=mlp,
-        sa_norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
-        mlp_norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
-    )
+        if apply_lora_to_mlp:
+            mlp = lora_gemma_mlp(
+                dim=embed_dim,
+                hidden_dim=intermediate_dim,
+                lora_rank=lora_rank,
+                lora_alpha=lora_alpha,
+                lora_dropout=lora_dropout,
+                use_dora=use_dora,
+                quantize_base=quantize_base,
+            )
+        else:
+            mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim, quantize_base=quantize_base)
+
+        layer = TransformerSelfAttentionLayer(
+            attn=self_attn,
+            mlp=mlp,
+            sa_norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
+            mlp_norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
+        )
+        layers.append(layer)
+
     tok_embeddings = GemmaNormEmbeddings(vocab_size, embed_dim)
     output_proj = TiedLinear(tok_embeddings)
     model = TransformerDecoder(
         tok_embeddings=tok_embeddings,
-        layers=layer,
-        num_layers=num_layers,
+        layers=layers,
         max_seq_len=max_seq_len,
         num_heads=num_heads,
         output=output_proj,
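
After this refactor, TransformerDecoder no longer receives num_layers from these builders, while the public gemma(...) / lora_gemma(...) entry points keep their signatures. A smoke test one might run looks like the sketch below; the import path matches this file, but the keyword arguments and the (batch, seq, vocab) output shape are inferred from the builder body above rather than stated in this diff.

import torch
from torchtune.models.gemma._component_builders import gemma

# Tiny, illustrative configuration -- not a real Gemma size.
model = gemma(
    vocab_size=128,
    num_layers=2,
    num_heads=4,
    head_dim=16,
    num_kv_heads=1,
    embed_dim=64,
    intermediate_dim=128,
    max_seq_len=64,
)

tokens = torch.randint(0, 128, (1, 8))
logits = model(tokens)
print(logits.shape)  # expected: torch.Size([1, 8, 128])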

torchtune/models/gemma2/_component_builders.py  (+3 -5)

@@ -4,9 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from torch import nn
 import torch
-from typing import List
+from torch import nn
 from torchtune.modules.common_utils import _register_reparametrize_state_dict_hooks
 from typing import List, Optional

@@ -116,7 +115,6 @@ def gemma2(
     rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
 
     layers = torch.nn.ModuleList()
-
     for layer_idx in range(num_layers):
 
         mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim)

@@ -149,6 +147,7 @@ def gemma2(
             mlp_scale=GemmaRMSNorm(embed_dim, eps=norm_eps),
         )
         layers.append(layer)
+
     tok_embeddings = GemmaNormEmbeddings(vocab_size, embed_dim)

@@ -231,8 +230,7 @@ def lora_gemma2(
     tok_embeddings = GemmaNormEmbeddings(vocab_size, embed_dim)
     output_proj = TiedLinear(tok_embeddings)
 
-    layers = torch.nn.ModuleList()
-
+    layers = nn.ModuleList()
     for layer_idx in range(num_layers):
         if apply_lora_to_mlp:
             mlp = lora_gemma_mlp(