
Commit 41d4bf2

Merge pull request #5 from togethercomputer/support-code-llama
llama: support RoPE theta for codellama
2 parents d0027d5 + 33a236b commit 41d4bf2

File tree: 2 files changed, +12 -5 lines


src/transformers/models/llama/configuration_llama.py

+3 -1
@@ -78,7 +78,7 @@ class LlamaConfig(PretrainedConfig):
         tie_word_embeddings(`bool`, *optional*, defaults to `False`):
             Whether to tie weight embeddings
         rope_scaling (`Dict`, *optional*):
-            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling
+            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
             strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format
             is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
             `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
@@ -122,6 +122,7 @@ def __init__(
         pretraining_tp=1,
         tie_word_embeddings=False,
         rope_scaling=None,
+        rope_theta=10000,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -143,6 +144,7 @@ def __init__(
         self.use_cache = use_cache
         self.rope_scaling = rope_scaling
         self._rope_scaling_validation()
+        self.rope_theta = rope_theta
 
         super().__init__(
             pad_token_id=pad_token_id,
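
The new `rope_theta` argument defaults to 10000, the RoPE base used by the original LLaMA checkpoints, so existing configs keep their behaviour; Code Llama raises it to spread the rotary frequencies over a longer context. A minimal usage sketch, assuming a transformers build that includes this patch (the 1e6 base and 16384-token window follow the Code Llama release and are illustrative values, not taken from this diff):

from transformers import LlamaConfig

# Default config: rope_theta falls back to 10000, matching the original LLaMA RoPE base.
config = LlamaConfig()
assert config.rope_theta == 10000

# Code Llama-style config: a larger RoPE base together with a longer position window.
# The exact values here are illustrative, not read from this commit.
code_llama_config = LlamaConfig(
    max_position_embeddings=16384,
    rope_theta=1_000_000,
)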

src/transformers/models/llama/modeling_llama.py

+9 -4
@@ -243,6 +243,7 @@ def __init__(self, config: LlamaConfig):
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.pretraining_tp = config.pretraining_tp
         self.max_position_embeddings = config.max_position_embeddings
+        self.rope_theta = config.rope_theta
 
         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
@@ -257,21 +258,25 @@ def __init__(self, config: LlamaConfig):
 
     def _init_rope(self):
         if self.config.rope_scaling is None:
-            self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
+            self.rotary_emb = LlamaRotaryEmbedding(
+                self.head_dim, max_position_embeddings=self.max_position_embeddings,
+                base=self.rope_theta
+            )
         else:
             scaling_type = self.config.rope_scaling["type"]
             scaling_factor = self.config.rope_scaling["factor"]
             if scaling_type == "linear":
                 self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
-                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
+                    self.head_dim, max_position_embeddings=self.max_position_embeddings,
+                    base=self.rope_theta, scaling_factor=scaling_factor
                 )
             elif scaling_type == "dynamic":
                 self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
-                    self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
+                    self.head_dim, max_position_embeddings=self.max_position_embeddings,
+                    base=self.rope_theta, scaling_factor=scaling_factor
                 )
             else:
                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
-
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
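
The `base` value threaded into the rotary embeddings above sets the geometric progression of RoPE rotation frequencies. As a rough sketch of why a larger base helps with long contexts, the inverse-frequency table that RoPE builds looks like the helper below (a standalone illustration of the standard RoPE formula, not code from this diff; `rope_inverse_frequencies` is a hypothetical name):

import torch

def rope_inverse_frequencies(head_dim: int, base: float = 10000.0) -> torch.Tensor:
    # Standard RoPE: inv_freq[i] = 1 / base**(2i / head_dim) for i = 0, 1, ..., head_dim // 2 - 1.
    # Raising `base` (e.g. toward 1e6 for long-context Code Llama-style models) lowers the
    # rotation frequencies, so positions far apart still map to distinct rotations.
    return 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))

# With head_dim = 128, compare the default LLaMA base with a much larger one.
default_freqs = rope_inverse_frequencies(128, base=10_000.0)
long_ctx_freqs = rope_inverse_frequencies(128, base=1_000_000.0)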

0 commit comments
