Llama/GPTNeoX: add RoPE scaling #24653
Changes from 1 commit
```diff
@@ -89,32 +89,51 @@ def forward(self, hidden_states):
 class LlamaRotaryEmbedding(torch.nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, rope_scaling=None):
         super().__init__()
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+        self.scaling_name = rope_scaling.get("name", "") if rope_scaling is not None else ""
+        self.scaling_factor = rope_scaling.get("factor") if rope_scaling is not None else None
+
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        if self.scaling_name == "ntk":
+            self.base = self.base * self.scaling_factor ** (self.dim / (self.dim - 2))
+
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
         self.register_buffer("inv_freq", inv_freq)

         # Build here to make `torch.jit.trace` work.
-        self.max_seq_len_cached = max_position_embeddings
-        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+        self._set_cos_sin_cache(
+            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
+        )
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+
+        if self.scaling_name == "dynamic":
```
jquesnelle (review comment): For dynamic, we want to do the scaling only when the sequence length exceeds the original maximum, i.e. `if self.scaling_name == "dynamic" and seq_len > self.max_position_embeddings:`

gante (reply): @jquesnelle that's a great catch, and completely missed in my attempts to unify and refactor the cos/sin cache!

(A standalone sketch applying this guard follows the hunk below.)
```diff
+            base = self.base * (
+                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
+            ) ** (self.dim / (self.dim - 2))
+            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+            self.register_buffer("inv_freq", inv_freq)
+
+        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
+
+        if self.scaling_name == "linear":
+            t = t / self.scaling_factor
+
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
         dtype = torch.get_default_dtype()
         self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
         self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
         # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
         if seq_len > self.max_seq_len_cached:
-            self.max_seq_len_cached = seq_len
-            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
-            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            # Different from paper, but it uses a different permutation in order to obtain the same calculation
-            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
-            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
+            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

         return (
             self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
             self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
```
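Not part of the diff: a minimal standalone sketch of how the three strategies affect the rotary tables, written against this commit's `{"name": ..., "factor": ...}` format and including the `seq_len > max_position_embeddings` guard suggested in the review thread above for "dynamic". The function name and toy dimensions are illustrative, not the PR's API.

```python
import torch

def rope_tables(dim, seq_len, max_position_embeddings=2048, base=10000.0, rope_scaling=None):
    """Build cos/sin RoPE tables under the scaling schemes discussed in this PR (sketch only)."""
    name = rope_scaling.get("name", "") if rope_scaling is not None else ""
    factor = rope_scaling.get("factor") if rope_scaling is not None else None

    if name == "ntk":
        # Fixed NTK-aware scaling: stretch the base once, independent of the current length.
        base = base * factor ** (dim / (dim - 2))
    elif name == "dynamic" and seq_len > max_position_embeddings:
        # Dynamic NTK: recompute the base from the current sequence length, but only once we
        # are past the original training context (the reviewer's suggested guard).
        base = base * ((factor * seq_len / max_position_embeddings) - (factor - 1)) ** (dim / (dim - 2))

    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(seq_len, dtype=inv_freq.dtype)
    if name == "linear":
        # Linear (position interpolation): compress the positions instead of changing the base.
        t = t / factor

    freqs = torch.einsum("i,j->ij", t, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos(), emb.sin()

# Example: a 4x context extension over a 2048-token pretraining window.
cos, sin = rope_tables(dim=128, seq_len=8192, rope_scaling={"name": "linear", "factor": 4.0})
print(cos.shape)  # torch.Size([8192, 128])
```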
```diff
@@ -176,7 +195,11 @@ def __init__(self, config: LlamaConfig):
         self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
         self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
-        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
+        self.rotary_emb = LlamaRotaryEmbedding(
+            dim=self.head_dim,
+            max_position_embeddings=self.max_position_embeddings,
+            rope_scaling=config.rope_scaling,
+        )

     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
```
Review note: (changes in open_llama are copy/paste)

````diff
@@ -67,6 +67,12 @@ class OpenLlamaConfig(PretrainedConfig):
             relevant if `config.is_decoder=True`.
         tie_word_embeddings(`bool`, *optional*, defaults to `False`):
             Whether to tie weight embeddings
+        rope_scaling (`Dict`, *optional*):
+            Experimental feature -- dictionary containing the scaling configuration for the RoPE embeddings. Currently
+            supports three scaling strategies: linear, ntk, and dynamic. Their scaling factor must be a float greater
+            than 1. The expected format is `{"name": strategy name, "factor": scaling factor}`. See the following
+            thread for more information on how these scaling strategies behave:
+            https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/
     Example:

     ```python
````
```diff
@@ -104,6 +110,7 @@ def __init__(
         attention_dropout_prob=0.1,
         use_stable_embedding=True,
         shared_input_output_embedding=True,
+        rope_scaling=None,
         **kwargs,
     ):
         self.vocab_size = vocab_size
```
```diff
@@ -123,6 +130,23 @@ def __init__(
         self.attention_dropout_prob = attention_dropout_prob
         self.use_stable_embedding = use_stable_embedding
         self.shared_input_output_embedding = shared_input_output_embedding
+        self.rope_scaling = rope_scaling
+
+        # RoPE scaling validation
+        if self.rope_scaling is not None:
+            if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
+                raise ValueError(
+                    f"`rope_scaling` must be a dictionary with two fields, `name` and `factor`, got {self.rope_scaling}"
+                )
+            rope_scaling_name = self.rope_scaling.get("name", None)
+            rope_scaling_factor = self.rope_scaling.get("factor", None)
+            if rope_scaling_name is None or rope_scaling_name not in ["linear", "ntk", "dynamic"]:
+                raise ValueError(
+                    f"`rope_scaling`'s name field must be one of ['linear', 'ntk', 'dynamic'], got {rope_scaling_name}"
+                )
+            if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
+                raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
+
         super().__init__(
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
```
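For anyone trying the branch, a hedged usage sketch of the format validated above. It assumes only what the diff shows: the config accepts a `rope_scaling` dict shaped as `{"name": ..., "factor": ...}` and rejects anything else. It is written against this commit's `name`/`factor` keys, which may change before merge, and it uses `LlamaConfig` (per the note that the open_llama changes are a copy of the Llama ones).

```python
# Sketch: exercising the validation added in this commit (assumes this PR branch is installed).
from transformers import LlamaConfig

# A valid configuration: 2x linear (position interpolation) scaling.
config = LlamaConfig(rope_scaling={"name": "linear", "factor": 2.0})

# These should all raise ValueError under the checks above:
for bad in (
    {"name": "linear"},                      # missing "factor" (only one field)
    {"name": "exponential", "factor": 2.0},  # unknown strategy name
    {"name": "ntk", "factor": 0.5},          # factor must be a float > 1
):
    try:
        LlamaConfig(rope_scaling=bad)
    except ValueError as err:
        print(f"rejected {bad}: {err}")
```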
Review comment (on the docstring above): I would put "Experimental feature" as a warning at the end of the docstring, e.g. "This is an experimental feature, subject to breaking API changes in future versions."
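A possible rendering of that suggestion, shown as a docstring excerpt only; the exact wording and placement are up to the author:

```text
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three
            scaling strategies: linear, ntk, and dynamic. Their scaling factor must be a float greater than 1.
            The expected format is `{"name": strategy name, "factor": scaling factor}`. See the following thread
            for more information on how these scaling strategies behave:
            https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/
            This is an experimental feature, subject to breaking API changes in future versions.
```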