@@ -76,33 +76,37 @@ def gemma(
         TransformerDecoder: Instantiation of gemma model.
     """
     rope = RotaryPositionalEmbeddings(dim=head_dim, max_seq_len=max_seq_len, base=rope_base)
-    self_att = MultiHeadAttention(
-        embed_dim=embed_dim,
-        num_heads=num_heads,
-        num_kv_heads=num_kv_heads,
-        head_dim=head_dim,
-        q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
-        k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
-        v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
-        output_proj=nn.Linear(num_heads * head_dim, embed_dim, bias=False),
-        pos_embeddings=rope,
-        kv_cache=None,
-        max_seq_len=max_seq_len,
-        attn_dropout=attn_dropout,
-    )
-    mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim)
-    layer = TransformerSelfAttentionLayer(
-        attn=self_att,
-        mlp=mlp,
-        sa_norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
-        mlp_norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
-    )
+
+    layers = nn.ModuleList()
+    for _ in range(num_layers):
+        self_att = MultiHeadAttention(
+            embed_dim=embed_dim,
+            num_heads=num_heads,
+            num_kv_heads=num_kv_heads,
+            head_dim=head_dim,
+            q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
+            k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
+            v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
+            output_proj=nn.Linear(num_heads * head_dim, embed_dim, bias=False),
+            pos_embeddings=rope,
+            kv_cache=None,
+            max_seq_len=max_seq_len,
+            attn_dropout=attn_dropout,
+        )
+        mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim)
+        layer = TransformerSelfAttentionLayer(
+            attn=self_att,
+            mlp=mlp,
+            sa_norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
+            mlp_norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
+        )
+        layers.append(layer)
+
     tok_embeddings = GemmaNormEmbeddings(vocab_size, embed_dim)
     output_proj = TiedLinear(tok_embeddings)
     model = TransformerDecoder(
         tok_embeddings=tok_embeddings,
-        layers=layer,
-        num_layers=num_layers,
+        layers=layers,
         max_seq_len=max_seq_len,
         num_heads=num_heads,
         output=output_proj,
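A note on the pattern in this hunk: instantiating each block inside the `for` loop is what gives every transformer layer its own parameters. The snippet below is a minimal, generic PyTorch sketch (plain `nn.Linear` blocks standing in for `TransformerSelfAttentionLayer`; none of it is torchtune code) contrasting a fresh module per iteration with a single reused instance:

```python
# Hypothetical illustration: fresh module per iteration vs. one shared instance.
from torch import nn

embed_dim, num_layers = 16, 4

# One instance referenced num_layers times -> a single (deduplicated) parameter set.
shared_block = nn.Linear(embed_dim, embed_dim)
shared = nn.ModuleList([shared_block for _ in range(num_layers)])

# New instance per iteration -> independently initialized and trained layers.
independent = nn.ModuleList([nn.Linear(embed_dim, embed_dim) for _ in range(num_layers)])

print(sum(p.numel() for p in shared.parameters()))       # 272  (one weight + one bias)
print(sum(p.numel() for p in independent.parameters()))  # 1088 (four weights + four biases)
```

Because the builder now materializes the full `nn.ModuleList` itself, `TransformerDecoder` receives it directly through `layers=layers`, and the separate `num_layers=num_layers` argument is no longer passed.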
@@ -186,47 +190,50 @@ def lora_gemma(
         TransformerDecoder: Instantiation of Gemma model with LoRA applied to
         a subset of the attention projections in each layer.
     """
-    self_attn = lora_gemma_self_attention(
-        lora_modules=lora_attn_modules,
-        embed_dim=embed_dim,
-        head_dim=head_dim,
-        num_heads=num_heads,
-        num_kv_heads=num_kv_heads,
-        max_seq_len=max_seq_len,
-        attn_dropout=attn_dropout,
-        rope_base=rope_base,
-        lora_rank=lora_rank,
-        lora_alpha=lora_alpha,
-        lora_dropout=lora_dropout,
-        use_dora=use_dora,
-        quantize_base=quantize_base,
-    )
-
-    if apply_lora_to_mlp:
-        mlp = lora_gemma_mlp(
-            dim=embed_dim,
-            hidden_dim=intermediate_dim,
+    layers = nn.ModuleList()
+    for _ in range(num_layers):
+        self_attn = lora_gemma_self_attention(
+            lora_modules=lora_attn_modules,
+            embed_dim=embed_dim,
+            head_dim=head_dim,
+            num_heads=num_heads,
+            num_kv_heads=num_kv_heads,
+            max_seq_len=max_seq_len,
+            attn_dropout=attn_dropout,
+            rope_base=rope_base,
             lora_rank=lora_rank,
             lora_alpha=lora_alpha,
             lora_dropout=lora_dropout,
             use_dora=use_dora,
             quantize_base=quantize_base,
         )
-    else:
-        mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim, quantize_base=quantize_base)
 
-    layer = TransformerSelfAttentionLayer(
-        attn=self_attn,
-        mlp=mlp,
-        sa_norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
-        mlp_norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
-    )
+        if apply_lora_to_mlp:
+            mlp = lora_gemma_mlp(
+                dim=embed_dim,
+                hidden_dim=intermediate_dim,
+                lora_rank=lora_rank,
+                lora_alpha=lora_alpha,
+                lora_dropout=lora_dropout,
+                use_dora=use_dora,
+                quantize_base=quantize_base,
+            )
+        else:
+            mlp = gemma_mlp(dim=embed_dim, hidden_dim=intermediate_dim, quantize_base=quantize_base)
+
+        layer = TransformerSelfAttentionLayer(
+            attn=self_attn,
+            mlp=mlp,
+            sa_norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
+            mlp_norm=GemmaRMSNorm(embed_dim, eps=norm_eps),
+        )
+        layers.append(layer)
+
     tok_embeddings = GemmaNormEmbeddings(vocab_size, embed_dim)
     output_proj = TiedLinear(tok_embeddings)
     model = TransformerDecoder(
         tok_embeddings=tok_embeddings,
-        layers=layer,
-        num_layers=num_layers,
+        layers=layers,
         max_seq_len=max_seq_len,
         num_heads=num_heads,
         output=output_proj,
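For readers less familiar with the knobs threaded through `lora_gemma_self_attention` and `lora_gemma_mlp`: `lora_rank`, `lora_alpha`, and `lora_dropout` configure the standard low-rank adapter update. The sketch below is a generic illustration of that formulation (it is not torchtune's `LoRALinear`, and the dimensions are made up): the frozen base projection is augmented with a low-rank `B @ A` path scaled by `alpha / rank`, and only the adapter matrices receive gradients.

```python
# Generic LoRA sketch (hypothetical, illustrative dimensions; not torchtune code).
import torch
from torch import nn

class LoRALinearSketch(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, rank: int, alpha: float, dropout: float = 0.0):
        super().__init__()
        self.base = nn.Linear(in_dim, out_dim, bias=False)
        self.base.weight.requires_grad_(False)              # frozen pretrained weight
        self.lora_a = nn.Linear(in_dim, rank, bias=False)    # A keeps its default init
        self.lora_b = nn.Linear(rank, out_dim, bias=False)   # B starts at zero, so training
        nn.init.zeros_(self.lora_b.weight)                   # begins from the base model
        self.dropout = nn.Dropout(dropout)
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Frozen base output plus the scaled low-rank update.
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(self.dropout(x)))

q_proj = LoRALinearSketch(in_dim=2048, out_dim=2048, rank=8, alpha=16, dropout=0.05)
out = q_proj(torch.randn(2, 4, 2048))  # (batch, seq, embed_dim)
```

In the hunk above, `apply_lora_to_mlp` decides whether each layer's feed-forward block gets these adapters (`lora_gemma_mlp`) or stays a plain, optionally quantized `gemma_mlp`.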