@@ -243,6 +243,7 @@ def __init__(self, config: LlamaConfig):
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.pretraining_tp = config.pretraining_tp
         self.max_position_embeddings = config.max_position_embeddings
+        self.rope_theta = config.rope_theta

         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
@@ -257,21 +258,25 @@ def __init__(self, config: LlamaConfig):

     def _init_rope(self):
         if self.config.rope_scaling is None:
-            self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
+            self.rotary_emb = LlamaRotaryEmbedding(
+                self.head_dim, max_position_embeddings=self.max_position_embeddings,
+                base=self.rope_theta
+            )
         else:
             scaling_type = self.config.rope_scaling["type"]
             scaling_factor = self.config.rope_scaling["factor"]
             if scaling_type == "linear":
                 self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
-                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
+                    self.head_dim, max_position_embeddings=self.max_position_embeddings,
+                    base=self.rope_theta, scaling_factor=scaling_factor
                 )
             elif scaling_type == "dynamic":
                 self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
-                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
+                    self.head_dim, max_position_embeddings=self.max_position_embeddings,
+                    base=self.rope_theta, scaling_factor=scaling_factor
                 )
             else:
                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
-
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
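For context on what the new `base` keyword does downstream: rope_theta is the RoPE base, which sets the geometric spacing of the rotary frequencies. Below is a minimal sketch of that relationship, assuming the standard RoPE inverse-frequency formula these embedding classes use (base ** (-2i / dim)); it is an illustration, not the library code:

import torch

def rope_inv_freq(head_dim: int, base: float = 10000.0) -> torch.Tensor:
    # One rotation frequency per channel pair: base ** (-2i / head_dim).
    return 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))

# A larger rope_theta shrinks the lowest frequency, i.e. stretches the longest
# RoPE wavelength, which is how long-context Llama variants slow down the
# positional rotation.
print(rope_inv_freq(128, base=10_000.0)[-1])     # default base
print(rope_inv_freq(128, base=1_000_000.0)[-1])  # larger base -> smaller lowest frequency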
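Since the diff reads `config.rope_theta`, the field must exist on LlamaConfig in this version. A hedged usage sketch of the plumbing end to end (the tiny config sizes are arbitrary, and the `.base` attribute access assumes the rotary embedding stores the value under that name, as in this file):

from transformers import LlamaConfig, LlamaForCausalLM

# rope_theta defaults to 10000.0; overriding it now reaches the rotary
# embeddings via the new base=self.rope_theta argument.
config = LlamaConfig(
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    rope_theta=1_000_000.0,
)
model = LlamaForCausalLM(config)  # randomly initialized, example only
print(model.model.layers[0].self_attn.rotary_emb.base)  # -> 1000000.0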