
Llama/GPTNeoX: add RoPE scaling #24653


Merged (9 commits) on Jul 13, 2023
Changes from 1 commit
25 changes: 25 additions & 0 deletions src/transformers/models/llama/configuration_llama.py
@@ -64,6 +64,13 @@ class LlamaConfig(PretrainedConfig):
relevant if `config.is_decoder=True`.
tie_word_embeddings(`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_scaling (`Dict`, *optional*):
Experimental feature -- dictionary containing the scaling configuration for the RoPE embeddings. Currently
supports three scaling strategies: linear, ntk, and dynamic. The scaling factor must be a float greater
than 1. The expected format is `{"name": strategy name, "factor": scaling factor}`. See the following
thread for more information on how these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/

Review comment (Collaborator):
I would put "Experimental feature" as a warning at the end of the docstring.

    This is an experimental feature, subject to breaking API changes in future versions.

Example:

```python
# (docstring usage example collapsed in the diff view)
```

@@ -97,6 +104,7 @@ def __init__(
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
rope_scaling=None,
**kwargs,
):
self.vocab_size = vocab_size
@@ -109,6 +117,23 @@ def __init__(
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_scaling = rope_scaling

# RoPE scaling validation
if self.rope_scaling is not None:
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
f"`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, got {self.rope_scaling}"
)
rope_scaling_name = self.rope_scaling.get("name", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_name is None or rope_scaling_name not in ["linear", "ntk", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s name field must be one of ['linear', 'ntk', 'dynamic'], got {rope_scaling_name}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
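For illustration only (this snippet is not part of the diff), a minimal sketch of how the new argument and its validation behave at this point in the PR. The `name`/`factor` key names reflect this commit, and the feature is explicitly experimental, so the exact schema may change:

```python
from transformers import LlamaConfig

# A valid scaling configuration: linear interpolation with a 2x factor.
config = LlamaConfig(rope_scaling={"name": "linear", "factor": 2.0})
print(config.rope_scaling)  # {'name': 'linear', 'factor': 2.0}

# Invalid configurations fail fast in __init__; an integer factor, for example,
# does not pass the `isinstance(..., float)` check in the validation above.
try:
    LlamaConfig(rope_scaling={"name": "linear", "factor": 2})
except ValueError as err:
    print(err)
```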
51 changes: 37 additions & 14 deletions src/transformers/models/llama/modeling_llama.py
@@ -89,32 +89,51 @@ def forward(self, hidden_states):


class LlamaRotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, rope_scaling=None):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
self.scaling_name = rope_scaling.get("name", "") if rope_scaling is not None else ""
self.scaling_factor = rope_scaling.get("factor") if rope_scaling is not None else None

self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
if self.scaling_name == "ntk":
self.base = self.base * self.scaling_factor ** (self.dim / (self.dim - 2))

inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq)

# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
)

def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len

if self.scaling_name == "dynamic":

Review comment:
For dynamic, we want to do the scaling only when seq_len > max_position_embeddings (i.e. when we're going past the model's pre-trained length). My original code did this by just having the scaling in the forward() code that re-calculated the frequency cache when seq_len > self.max_seq_len_cached, but not in the __init__. Since this code has now been deduplicated (makes sense!), I think this needs to be:

    if self.scaling_name == "dynamic" and seq_len > self.max_position_embeddings:

Reply (Member, author):
@jquesnelle that's a great catch, and completely missed in my attempts to unify and refactor the cos/sin cache!

base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq)

t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

if self.scaling_name == "linear":
t = t / self.scaling_factor

freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
dtype = torch.get_default_dtype()
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
@@ -176,7 +195,11 @@ def __init__(self, config: LlamaConfig):
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
self.rotary_emb = LlamaRotaryEmbedding(
dim=self.head_dim,
max_position_embeddings=self.max_position_embeddings,
rope_scaling=config.rope_scaling,
)

def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
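To make the three strategies easier to compare, here is a small standalone sketch (not the PR code itself; the helper name and the numbers are illustrative) of how each one alters the rotary frequencies or positions. The dynamic branch already includes the `seq_len > max_position_embeddings` guard suggested in the review above:

```python
import torch

def scaled_rope_inputs(dim, seq_len, max_position_embeddings=2048,
                       base=10000.0, name="", factor=1.0):
    # "ntk": stretch the base once, independently of the current sequence length.
    if name == "ntk":
        base = base * factor ** (dim / (dim - 2))
    # "dynamic": stretch the base as a function of the current sequence length,
    # but only once it exceeds the pre-trained context (per the review discussion).
    elif name == "dynamic" and seq_len > max_position_embeddings:
        base = base * (
            (factor * seq_len / max_position_embeddings) - (factor - 1)
        ) ** (dim / (dim - 2))

    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(seq_len).float()
    # "linear": keep the frequencies, compress the position indices instead.
    if name == "linear":
        t = t / factor
    return inv_freq, t

# Linear scaling with factor 2.0 halves the effective position indices...
_, t = scaled_rope_inputs(dim=128, seq_len=4096, name="linear", factor=2.0)
print(t[-1])  # tensor(2047.5000)

# ...while ntk/dynamic keep the positions and lower the frequencies instead.
inv_freq_base, _ = scaled_rope_inputs(dim=128, seq_len=4096)
inv_freq_ntk, _ = scaled_rope_inputs(dim=128, seq_len=4096, name="ntk", factor=2.0)
print((inv_freq_ntk <= inv_freq_base).all())  # tensor(True)
```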
24 changes: 24 additions & 0 deletions src/transformers/models/open_llama/configuration_open_llama.py
@@ -67,6 +67,12 @@ class OpenLlamaConfig(PretrainedConfig):
relevant if `config.is_decoder=True`.
tie_word_embeddings(`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_scaling (`Dict`, *optional*):
Experimental feature -- dictionary containing the scaling configuration for the RoPE embeddings. Currently
supports three scaling strategies: linear, ntk, and dynamic. The scaling factor must be a float greater
than 1. The expected format is `{"name": strategy name, "factor": scaling factor}`. See the following
thread for more information on how these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/

Review comment (Member, author):
(changes in open_llama are copy/paste)

Example:

```python
# (docstring usage example collapsed in the diff view)
```

@@ -104,6 +110,7 @@ def __init__(
attention_dropout_prob=0.1,
use_stable_embedding=True,
shared_input_output_embedding=True,
rope_scaling=None,
**kwargs,
):
self.vocab_size = vocab_size
@@ -123,6 +130,23 @@ def __init__(
self.attention_dropout_prob = attention_dropout_prob
self.use_stable_embedding = use_stable_embedding
self.shared_input_output_embedding = shared_input_output_embedding
self.rope_scaling = rope_scaling

# RoPE scaling validation
if self.rope_scaling is not None:
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
f"`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, got {self.rope_scaling}"
)
rope_scaling_name = self.rope_scaling.get("name", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_name is None or rope_scaling_name not in ["linear", "ntk", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s name field must be one of ['linear', 'ntk', 'dynamic'], got {rope_scaling_name}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}")

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
51 changes: 37 additions & 14 deletions src/transformers/models/open_llama/modeling_open_llama.py
@@ -100,32 +100,51 @@ def forward(self, hidden_states):

# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->OpenLlama
class OpenLlamaRotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, rope_scaling=None):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
self.scaling_name = rope_scaling.get("name", "") if rope_scaling is not None else ""
self.scaling_factor = rope_scaling.get("factor") if rope_scaling is not None else None

self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
if self.scaling_name == "ntk":
self.base = self.base * self.scaling_factor ** (self.dim / (self.dim - 2))

inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq)

# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
)

def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len

if self.scaling_name == "dynamic":
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq)

t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

if self.scaling_name == "linear":
t = t / self.scaling_factor

freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
dtype = torch.get_default_dtype()
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
@@ -190,7 +209,11 @@ def __init__(self, config: OpenLlamaConfig):
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
self.rotary_emb = OpenLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
self.rotary_emb = OpenLlamaRotaryEmbedding(
dim=self.head_dim,
max_position_embeddings=self.max_position_embeddings,
rope_scaling=config.rope_scaling,
)

def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
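Finally, assuming a Llama checkpoint is available (the repo id below is only a placeholder), the scaling dictionary can be threaded through as a config override when loading, which is how an end user would reach the new code path. This is a sketch of the intended usage, not part of the diff:

```python
from transformers import AutoConfig, AutoModelForCausalLM

model_id = "path/to/llama-checkpoint"  # placeholder -- substitute any Llama checkpoint

# Override the config with a dynamic scaling setup, then load the weights with it.
config = AutoConfig.from_pretrained(model_id, rope_scaling={"name": "dynamic", "factor": 2.0})
model = AutoModelForCausalLM.from_pretrained(model_id, config=config)
```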