From e7d743fd955cc4c47052eae8aa258515191ece66 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 4 Jul 2023 17:05:01 +0000 Subject: [PATCH 1/8] add rope_scaling --- .../models/llama/configuration_llama.py | 25 +++++++++ .../models/llama/modeling_llama.py | 51 ++++++++++++++----- .../open_llama/configuration_open_llama.py | 24 +++++++++ .../models/open_llama/modeling_open_llama.py | 51 ++++++++++++++----- 4 files changed, 123 insertions(+), 28 deletions(-) diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py index e3d075310e84..fb489f31a936 100644 --- a/src/transformers/models/llama/configuration_llama.py +++ b/src/transformers/models/llama/configuration_llama.py @@ -64,6 +64,13 @@ class LlamaConfig(PretrainedConfig): relevant if `config.is_decoder=True`. tie_word_embeddings(`bool`, *optional*, defaults to `False`): Whether to tie weight embeddings + rope_scaling (`Dict`, *optional*): + Experimental feature -- dictionary containing the scaling configuration for the RoPE embeddings. Currently + supports three scaling strategies: linear, ntk, and dynamic. Their scaling factor must be an float greater + than 1. The expected format is `{"name": strategy name, "factor": scaling factor}`. See the following + thread for more information on how these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/ + Example: ```python @@ -97,6 +104,7 @@ def __init__( bos_token_id=1, eos_token_id=2, tie_word_embeddings=False, + rope_scaling=None, **kwargs, ): self.vocab_size = vocab_size @@ -109,6 +117,23 @@ def __init__( self.initializer_range = initializer_range self.rms_norm_eps = rms_norm_eps self.use_cache = use_cache + self.rope_scaling = rope_scaling + + # RoPE scaling validation + if self.rope_scaling is not None: + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + f"`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, got {self.rope_scaling}" + ) + rope_scaling_name = self.rope_scaling.get("name", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_name is None or rope_scaling_name not in ["linear", "ntk", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s name field must be one of ['linear', 'ntk', 'dynamic'], got {rope_scaling_name}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") + super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 24231c3f777d..2d5f363290f7 100755 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -89,32 +89,51 @@ def forward(self, hidden_states): class LlamaRotaryEmbedding(torch.nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, rope_scaling=None): super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.scaling_name = rope_scaling.get("name", "") if rope_scaling is not None else "" + self.scaling_factor = rope_scaling.get("factor") if rope_scaling is not None else None + + self.dim = dim + 
self.max_position_embeddings = max_position_embeddings + self.base = base + if self.scaling_name == "ntk": + self.base = self.base * self.scaling_factor ** (self.dim / (self.dim - 2)) + + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq) # Build here to make `torch.jit.trace` work. - self.max_seq_len_cached = max_position_embeddings - t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if self.scaling_name == "dynamic": + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + if self.scaling_name == "linear": + t = t / self.scaling_factor + freqs = torch.einsum("i,j->ij", t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - dtype = torch.get_default_dtype() self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] - # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. 
if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1).to(x.device) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False) + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + return ( self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), @@ -176,7 +195,11 @@ def __init__(self, config: LlamaConfig): self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) + self.rotary_emb = LlamaRotaryEmbedding( + dim=self.head_dim, + max_position_embeddings=self.max_position_embeddings, + rope_scaling=config.rope_scaling, + ) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() diff --git a/src/transformers/models/open_llama/configuration_open_llama.py b/src/transformers/models/open_llama/configuration_open_llama.py index cbde4d67d498..a097b52f09af 100644 --- a/src/transformers/models/open_llama/configuration_open_llama.py +++ b/src/transformers/models/open_llama/configuration_open_llama.py @@ -67,6 +67,12 @@ class OpenLlamaConfig(PretrainedConfig): relevant if `config.is_decoder=True`. tie_word_embeddings(`bool`, *optional*, defaults to `False`): Whether to tie weight embeddings + rope_scaling (`Dict`, *optional*): + Experimental feature -- dictionary containing the scaling configuration for the RoPE embeddings. Currently + supports three scaling strategies: linear, ntk, and dynamic. Their scaling factor must be an float greater + than 1. The expected format is `{"name": strategy name, "factor": scaling factor}`. 
See the following + thread for more information on how these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/ Example: ```python @@ -104,6 +110,7 @@ def __init__( attention_dropout_prob=0.1, use_stable_embedding=True, shared_input_output_embedding=True, + rope_scaling=None, **kwargs, ): self.vocab_size = vocab_size @@ -123,6 +130,23 @@ def __init__( self.attention_dropout_prob = attention_dropout_prob self.use_stable_embedding = use_stable_embedding self.shared_input_output_embedding = shared_input_output_embedding + self.rope_scaling = rope_scaling + + # RoPE scaling validation + if self.rope_scaling is not None: + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + f"`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, got {self.rope_scaling}" + ) + rope_scaling_name = self.rope_scaling.get("name", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_name is None or rope_scaling_name not in ["linear", "ntk", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s name field must be one of ['linear', 'ntk', 'dynamic'], got {rope_scaling_name}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") + super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, diff --git a/src/transformers/models/open_llama/modeling_open_llama.py b/src/transformers/models/open_llama/modeling_open_llama.py index 84d5c6e78fa2..ed91684846ea 100644 --- a/src/transformers/models/open_llama/modeling_open_llama.py +++ b/src/transformers/models/open_llama/modeling_open_llama.py @@ -100,32 +100,51 @@ def forward(self, hidden_states): # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->OpenLlama class OpenLlamaRotaryEmbedding(torch.nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, rope_scaling=None): super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.scaling_name = rope_scaling.get("name", "") if rope_scaling is not None else "" + self.scaling_factor = rope_scaling.get("factor") if rope_scaling is not None else None + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + if self.scaling_name == "ntk": + self.base = self.base * self.scaling_factor ** (self.dim / (self.dim - 2)) + + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq) # Build here to make `torch.jit.trace` work. 
- self.max_seq_len_cached = max_position_embeddings - t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if self.scaling_name == "dynamic": + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + if self.scaling_name == "linear": + t = t / self.scaling_factor + freqs = torch.einsum("i,j->ij", t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - dtype = torch.get_default_dtype() self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] - # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1).to(x.device) - self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False) + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + return ( self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), @@ -190,7 +209,11 @@ def __init__(self, config: OpenLlamaConfig): self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self.rotary_emb = OpenLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) + self.rotary_emb = OpenLlamaRotaryEmbedding( + dim=self.head_dim, + max_position_embeddings=self.max_position_embeddings, + rope_scaling=config.rope_scaling, + ) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() From 11a25dae4cc070b18efd58e69a58620ff49a96f0 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 11 Jul 2023 17:54:49 +0000 Subject: [PATCH 2/8] tmp commit --- .../models/llama/configuration_llama.py | 38 ++++---- .../models/llama/modeling_llama.py | 93 ++++++++++++++----- .../open_llama/configuration_open_llama.py | 28 +++++- .../models/open_llama/modeling_open_llama.py | 74 +++++++++++---- 4 files changed, 175 insertions(+), 58 deletions(-) diff --git a/src/transformers/models/llama/configuration_llama.py 
b/src/transformers/models/llama/configuration_llama.py index fb489f31a936..c3948b0e67f8 100644 --- a/src/transformers/models/llama/configuration_llama.py +++ b/src/transformers/models/llama/configuration_llama.py @@ -65,11 +65,12 @@ class LlamaConfig(PretrainedConfig): tie_word_embeddings(`bool`, *optional*, defaults to `False`): Whether to tie weight embeddings rope_scaling (`Dict`, *optional*): - Experimental feature -- dictionary containing the scaling configuration for the RoPE embeddings. Currently + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling strategies: linear, ntk, and dynamic. Their scaling factor must be an float greater - than 1. The expected format is `{"name": strategy name, "factor": scaling factor}`. See the following + than 1. The expected format is `{"type": strategy name, "factor": scaling factor}`. See the following thread for more information on how these scaling strategies behave: - https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/ + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. + This is an experimental feature, subject to breaking API changes in future versions. Example: @@ -118,26 +119,31 @@ def __init__( self.rms_norm_eps = rms_norm_eps self.use_cache = use_cache self.rope_scaling = rope_scaling + self._rope_scaling_validation() - # RoPE scaling validation + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ if self.rope_scaling is not None: if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: raise ValueError( - f"`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, got {self.rope_scaling}" + "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, " + f"got {self.rope_scaling}" ) - rope_scaling_name = self.rope_scaling.get("name", None) + rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_name is None or rope_scaling_name not in ["linear", "ntk", "dynamic"]: + if rope_scaling_type is None or rope_scaling_type not in ["linear", "ntk", "dynamic"]: raise ValueError( - f"`rope_scaling`'s name field must be one of ['linear', 'ntk', 'dynamic'], got {rope_scaling_name}" + f"`rope_scaling`'s name field must be one of ['linear', 'ntk', 'dynamic'], got {rope_scaling_type}" ) if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") - - super().__init__( - pad_token_id=pad_token_id, - bos_token_id=bos_token_id, - eos_token_id=eos_token_id, - tie_word_embeddings=tie_word_embeddings, - **kwargs, - ) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 2d5f363290f7..03c9ef46e351 100755 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -89,17 +89,12 @@ def forward(self, hidden_states): class LlamaRotaryEmbedding(torch.nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, rope_scaling=None): + def __init__(self, dim, max_position_embeddings=2048, base=10000, 
device=None): super().__init__() - self.scaling_name = rope_scaling.get("name", "") if rope_scaling is not None else "" - self.scaling_factor = rope_scaling.get("factor") if rope_scaling is not None else None self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - if self.scaling_name == "ntk": - self.base = self.base * self.scaling_factor ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq) @@ -110,19 +105,8 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, r def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len - - if self.scaling_name == "dynamic": - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) - ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq) - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - if self.scaling_name == "linear": - t = t / self.scaling_factor - freqs = torch.einsum("i,j->ij", t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) @@ -139,6 +123,56 @@ def forward(self, x, seq_len=None): self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), ) +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + +class LlamaNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with NTK scaling. Credits to the Reddit user /u/bloc97""" + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + base = base * scaling_factor ** (dim / (dim - 2)) + super().__init__(dim, max_position_embeddings, base, device) + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit user /u/emozilla""" + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -195,11 +229,26 @@ def __init__(self, config: LlamaConfig): self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self.rotary_emb = LlamaRotaryEmbedding( - dim=self.head_dim, - max_position_embeddings=self.max_position_embeddings, - rope_scaling=config.rope_scaling, - ) + + if config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) + else: + scaling_type = config.rope_scaling["type"] + scaling_factor = config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor + ) + elif scaling_type == "ntk": + self.rotary_emb = LlamaNTKScalingRotaryEmbedding( + self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() diff --git a/src/transformers/models/open_llama/configuration_open_llama.py b/src/transformers/models/open_llama/configuration_open_llama.py index a097b52f09af..d63c6042ce11 100644 --- a/src/transformers/models/open_llama/configuration_open_llama.py +++ b/src/transformers/models/open_llama/configuration_open_llama.py @@ -68,11 +68,12 @@ class OpenLlamaConfig(PretrainedConfig): tie_word_embeddings(`bool`, *optional*, defaults to `False`): Whether to tie weight embeddings rope_scaling (`Dict`, *optional*): - Experimental feature -- dictionary containing the scaling configuration for the RoPE embeddings. Currently + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling strategies: linear, ntk, and dynamic. 
Their scaling factor must be an float greater - than 1. The expected format is `{"name": strategy name, "factor": scaling factor}`. See the following + than 1. The expected format is `{"type": strategy name, "factor": scaling factor}`. See the following thread for more information on how these scaling strategies behave: - https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/ + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. + This is an experimental feature, subject to breaking API changes in future versions. Example: ```python @@ -131,6 +132,7 @@ def __init__( self.use_stable_embedding = use_stable_embedding self.shared_input_output_embedding = shared_input_output_embedding self.rope_scaling = rope_scaling + self._rope_scaling_validation() # RoPE scaling validation if self.rope_scaling is not None: @@ -154,3 +156,23 @@ def __init__( tie_word_embeddings=tie_word_embeddings, **kwargs, ) + + #Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is not None: + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, " + f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "ntk", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s name field must be one of ['linear', 'ntk', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") diff --git a/src/transformers/models/open_llama/modeling_open_llama.py b/src/transformers/models/open_llama/modeling_open_llama.py index ed91684846ea..6c0d9bb6c65c 100644 --- a/src/transformers/models/open_llama/modeling_open_llama.py +++ b/src/transformers/models/open_llama/modeling_open_llama.py @@ -100,17 +100,12 @@ def forward(self, hidden_states): # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->OpenLlama class OpenLlamaRotaryEmbedding(torch.nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, rope_scaling=None): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() - self.scaling_name = rope_scaling.get("name", "") if rope_scaling is not None else "" - self.scaling_factor = rope_scaling.get("factor") if rope_scaling is not None else None self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - if self.scaling_name == "ntk": - self.base = self.base * self.scaling_factor ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq) @@ -121,19 +116,8 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, r def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len - - if self.scaling_name == "dynamic": - base = self.base * ( - (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) - ) ** 
(self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq) - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) - if self.scaling_name == "linear": - t = t / self.scaling_factor - freqs = torch.einsum("i,j->ij", t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) @@ -150,6 +134,62 @@ def forward(self, x, seq_len=None): self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), ) +# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->OpenLlama +class LlamaLinearScalingRotaryEmbedding(OpenLlamaRotaryEmbedding): + """OpenLlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + +# Copied from transformers.models.llama.modeling_llama.LlamaNTKScalingRotaryEmbedding with Llama->OpenLlama +class LlamaNTKScalingRotaryEmbedding(OpenLlamaRotaryEmbedding): + """OpenLlamaRotaryEmbedding extended with NTK scaling. Credits to the Reddit user /u/bloc97""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + base = base * scaling_factor ** (dim / (dim - 2)) + super().__init__(dim, max_position_embeddings, base, device) + + +# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->OpenLlama +class LlamaDynamicNTKScalingRotaryEmbedding(OpenLlamaRotaryEmbedding): + """OpenLlamaRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit user /u/emozilla""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + def rotate_half(x): """Rotates half the hidden dims of the input.""" From 022c977e0ea16d3c1933c6d809f5d3a83393e66b Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 12 Jul 2023 14:16:06 +0000 Subject: [PATCH 3/8] add gptneox --- .../models/gpt_neox/configuration_gpt_neox.py | 32 +++++ .../models/gpt_neox/modeling_gpt_neox.py | 119 +++++++++++++++--- .../modeling_gpt_neox_japanese.py | 27 ++-- .../models/llama/configuration_llama.py | 12 +- .../models/llama/modeling_llama.py | 12 +- .../open_llama/configuration_open_llama.py | 14 +-- .../models/open_llama/modeling_open_llama.py | 37 ++++-- 7 files changed, 202 insertions(+), 51 deletions(-) diff --git a/src/transformers/models/gpt_neox/configuration_gpt_neox.py b/src/transformers/models/gpt_neox/configuration_gpt_neox.py index 000d566398e4..432acc70a48e 100644 --- a/src/transformers/models/gpt_neox/configuration_gpt_neox.py +++ b/src/transformers/models/gpt_neox/configuration_gpt_neox.py @@ -73,6 +73,14 @@ class GPTNeoXConfig(PretrainedConfig): use_parallel_residual (`bool`, *optional*, defaults to `True`): Whether to use a "parallel" formulation in each Transformer layer, which can provide a slight training speedup at large scales (e.g. 20B). + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling + strategies: linear, ntk, and dynamic. Their scaling factor must be an float greater than 1. The expected + format is `{"type": strategy name, "factor": scaling factor}`. See the following thread for more + information on how these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. 
+ Example: ```python @@ -108,6 +116,7 @@ def __init__( eos_token_id=2, tie_word_embeddings=False, use_parallel_residual=True, + rope_scaling=None, **kwargs, ): super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) @@ -126,7 +135,30 @@ def __init__( self.use_cache = use_cache self.tie_word_embeddings = tie_word_embeddings self.use_parallel_residual = use_parallel_residual + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + if self.hidden_size % self.num_attention_heads != 0: raise ValueError( "The hidden size is not divisble by the number of attention heads! Make sure to update them!" ) + + # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is not None: + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, " + f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "ntk", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s name field must be one of ['linear', 'ntk', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 841cbe1aa8f2..afb47d0c7cf6 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -86,6 +86,7 @@ def _set_gradient_checkpointing(self, module, value=False): class GPTNeoXAttention(nn.Module): def __init__(self, config): super().__init__() + self.config = config self.num_attention_heads = config.num_attention_heads self.hidden_size = config.hidden_size if self.hidden_size % self.num_attention_heads != 0: @@ -103,9 +104,8 @@ def __init__(self, config): persistent=False, ) self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False) - self.rotary_emb = RotaryEmbedding( - self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base - ) + self._init_rope() + self.register_buffer( "norm_factor", torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()), @@ -114,6 +114,38 @@ def __init__(self, config): self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size) + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = GPTNeoXRotaryEmbedding( + self.rotary_ndims, self.config.max_position_embeddings, base=self.config.rotary_emb_base + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = GPTNeoXLinearScalingRotaryEmbedding( + self.rotary_ndims, + self.config.max_position_embeddings, + base=self.config.rotary_emb_base, + scaling_factor=scaling_factor, + ) + elif scaling_type == "ntk": + self.rotary_emb = GPTNeoXNTKScalingRotaryEmbedding( + self.rotary_ndims, + self.config.max_position_embeddings, + 
base=self.config.rotary_emb_base, + scaling_factor=scaling_factor, + ) + elif scaling_type == "dynamic": + self.rotary_emb = GPTNeoXDynamicNTKScalingRotaryEmbedding( + self.rotary_ndims, + self.config.max_position_embeddings, + base=self.config.rotary_emb_base, + scaling_factor=scaling_factor, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + def forward( self, hidden_states: torch.FloatTensor, @@ -254,15 +286,25 @@ def attention_mask_func(attention_scores, ltor_mask): return attention_scores -class RotaryEmbedding(torch.nn.Module): +class GPTNeoXRotaryEmbedding(torch.nn.Module): def __init__(self, dim, max_position_embeddings, base=10000, device=None): super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq) # Build here to make `torch.jit.trace` work. - self.max_seq_len_cached = max_position_embeddings - t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) @@ -271,18 +313,65 @@ def __init__(self, dim, max_position_embeddings, base=10000, device=None): def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] - # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1).to(x.device) - self.cos_cached = emb.cos()[None, None, :, :] - self.sin_cached = emb.sin()[None, None, :, :] + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device) +class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding): + """GPTNeoXRotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, dim, max_position_embeddings, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cached = emb.cos()[None, None, :, :] + self.sin_cached = emb.sin()[None, None, :, :] + + +class GPTNeoXNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding): + """GPTNeoXRotaryEmbedding extended with NTK scaling. Credits to the Reddit user /u/bloc97""" + + def __init__(self, dim, max_position_embeddings, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + base = base * scaling_factor ** (dim / (dim - 2)) + super().__init__(dim, max_position_embeddings, base, device) + + +class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding): + """GPTNeoXRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit user /u/emozilla""" + + def __init__(self, dim, max_position_embeddings, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.cos_cached = emb.cos()[None, None, :, :] + self.sin_cached = emb.sin()[None, None, :, :] + + def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index e7cb510e6222..7f04e61ffef7 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -238,16 +238,26 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): return attn_output, attn_weights -# Copied from transformers.models.gpt_neox.modeling_gpt_neox.RotaryEmbedding +# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding class RotaryEmbedding(torch.nn.Module): def __init__(self, dim, max_position_embeddings, base=10000, device=None): super().__init__() - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq) # Build here 
to make `torch.jit.trace` work. - self.max_seq_len_cached = max_position_embeddings - t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) @@ -256,15 +266,8 @@ def __init__(self, dim, max_position_embeddings, base=10000, device=None): def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] - # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. if seq_len > self.max_seq_len_cached: - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) - freqs = torch.einsum("i,j->ij", t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1).to(x.device) - self.cos_cached = emb.cos()[None, None, :, :] - self.sin_cached = emb.sin()[None, None, :, :] + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device) diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py index c3948b0e67f8..7d3cce542185 100644 --- a/src/transformers/models/llama/configuration_llama.py +++ b/src/transformers/models/llama/configuration_llama.py @@ -65,12 +65,12 @@ class LlamaConfig(PretrainedConfig): tie_word_embeddings(`bool`, *optional*, defaults to `False`): Whether to tie weight embeddings rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. Currently - supports three scaling strategies: linear, ntk, and dynamic. Their scaling factor must be an float greater - than 1. The expected format is `{"type": strategy name, "factor": scaling factor}`. See the following - thread for more information on how these scaling strategies behave: - https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. - This is an experimental feature, subject to breaking API changes in future versions. + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling + strategies: linear, ntk, and dynamic. Their scaling factor must be an float greater than 1. The expected + format is `{"type": strategy name, "factor": scaling factor}`. See the following thread for more + information on how these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. 
Example: diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 03c9ef46e351..b59bedc5056a 100755 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -123,8 +123,10 @@ def forward(self, x, seq_len=None): self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), ) + class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): self.scaling_factor = scaling_factor super().__init__(dim, max_position_embeddings, base, device) @@ -143,6 +145,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): class LlamaNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): """LlamaRotaryEmbedding extended with NTK scaling. Credits to the Reddit user /u/bloc97""" + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): self.scaling_factor = scaling_factor base = base * scaling_factor ** (dim / (dim - 2)) @@ -151,6 +154,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit user /u/emozilla""" + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): self.scaling_factor = scaling_factor super().__init__(dim, max_position_embeddings, base, device) @@ -229,12 +233,14 @@ def __init__(self, config: LlamaConfig): self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self._init_rope() - if config.rope_scaling is None: + def _init_rope(self): + if self.config.rope_scaling is None: self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) else: - scaling_type = config.rope_scaling["type"] - scaling_factor = config.rope_scaling["factor"] + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] if scaling_type == "linear": self.rotary_emb = LlamaLinearScalingRotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor diff --git a/src/transformers/models/open_llama/configuration_open_llama.py b/src/transformers/models/open_llama/configuration_open_llama.py index d63c6042ce11..dd27b56d90a5 100644 --- a/src/transformers/models/open_llama/configuration_open_llama.py +++ b/src/transformers/models/open_llama/configuration_open_llama.py @@ -68,12 +68,12 @@ class OpenLlamaConfig(PretrainedConfig): tie_word_embeddings(`bool`, *optional*, defaults to `False`): Whether to tie weight embeddings rope_scaling (`Dict`, *optional*): - Dictionary containing the scaling configuration for the RoPE embeddings. Currently - supports three scaling strategies: linear, ntk, and dynamic. Their scaling factor must be an float greater - than 1. The expected format is `{"type": strategy name, "factor": scaling factor}`. See the following - thread for more information on how these scaling strategies behave: - https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. 
- This is an experimental feature, subject to breaking API changes in future versions. + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling + strategies: linear, ntk, and dynamic. Their scaling factor must be an float greater than 1. The expected + format is `{"type": strategy name, "factor": scaling factor}`. See the following thread for more + information on how these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. Example: ```python @@ -157,7 +157,7 @@ def __init__( **kwargs, ) - #Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation + # Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation def _rope_scaling_validation(self): """ Validate the `rope_scaling` configuration. diff --git a/src/transformers/models/open_llama/modeling_open_llama.py b/src/transformers/models/open_llama/modeling_open_llama.py index 6c0d9bb6c65c..fabae83ab1a1 100644 --- a/src/transformers/models/open_llama/modeling_open_llama.py +++ b/src/transformers/models/open_llama/modeling_open_llama.py @@ -134,8 +134,9 @@ def forward(self, x, seq_len=None): self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), ) + # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->OpenLlama -class LlamaLinearScalingRotaryEmbedding(OpenLlamaRotaryEmbedding): +class OpenLlamaLinearScalingRotaryEmbedding(OpenLlamaRotaryEmbedding): """OpenLlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): @@ -155,7 +156,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): # Copied from transformers.models.llama.modeling_llama.LlamaNTKScalingRotaryEmbedding with Llama->OpenLlama -class LlamaNTKScalingRotaryEmbedding(OpenLlamaRotaryEmbedding): +class OpenLlamaNTKScalingRotaryEmbedding(OpenLlamaRotaryEmbedding): """OpenLlamaRotaryEmbedding extended with NTK scaling. Credits to the Reddit user /u/bloc97""" def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): @@ -165,7 +166,7 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->OpenLlama -class LlamaDynamicNTKScalingRotaryEmbedding(OpenLlamaRotaryEmbedding): +class OpenLlamaDynamicNTKScalingRotaryEmbedding(OpenLlamaRotaryEmbedding): """OpenLlamaRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit user /u/emozilla""" def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): @@ -249,11 +250,31 @@ def __init__(self, config: OpenLlamaConfig): self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) - self.rotary_emb = OpenLlamaRotaryEmbedding( - dim=self.head_dim, - max_position_embeddings=self.max_position_embeddings, - rope_scaling=config.rope_scaling, - ) + self._init_rope() + + # Copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->OpenLlama + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = OpenLlamaRotaryEmbedding( + self.head_dim, max_position_embeddings=self.max_position_embeddings + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = OpenLlamaLinearScalingRotaryEmbedding( + self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor + ) + elif scaling_type == "ntk": + self.rotary_emb = OpenLlamaNTKScalingRotaryEmbedding( + self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor + ) + elif scaling_type == "dynamic": + self.rotary_emb = OpenLlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() From c12aaf114b729480f4bab287e980a5b0b5105dec Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 12 Jul 2023 15:56:18 +0000 Subject: [PATCH 4/8] add tests --- .../models/gpt_neox/configuration_gpt_neox.py | 5 ++- .../models/gpt_neox/modeling_gpt_neox.py | 30 +++++++++------ .../modeling_gpt_neox_japanese.py | 6 +-- .../models/llama/configuration_llama.py | 5 ++- .../open_llama/configuration_open_llama.py | 6 ++- .../models/gpt_neox/test_modeling_gpt_neox.py | 37 ++++++++++++++++++- tests/models/llama/test_modeling_llama.py | 35 +++++++++++++++++- .../open_llama/test_modeling_open_llama.py | 35 +++++++++++++++++- 8 files changed, 133 insertions(+), 26 deletions(-) diff --git a/src/transformers/models/gpt_neox/configuration_gpt_neox.py b/src/transformers/models/gpt_neox/configuration_gpt_neox.py index 432acc70a48e..a43bc15d30df 100644 --- a/src/transformers/models/gpt_neox/configuration_gpt_neox.py +++ b/src/transformers/models/gpt_neox/configuration_gpt_neox.py @@ -76,8 +76,9 @@ class GPTNeoXConfig(PretrainedConfig): rope_scaling (`Dict`, *optional*): Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling strategies: linear, ntk, and dynamic. Their scaling factor must be an float greater than 1. The expected - format is `{"type": strategy name, "factor": scaling factor}`. See the following thread for more - information on how these scaling strategies behave: + format is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. 
See the following thread for more information on how + these scaling strategies behave: https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an experimental feature, subject to breaking API changes in future versions. diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index afb47d0c7cf6..ad262682f5f7 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -95,14 +95,8 @@ def __init__(self, config): ) self.head_size = self.hidden_size // self.num_attention_heads self.rotary_ndims = int(self.head_size * config.rotary_pct) - max_positions = config.max_position_embeddings - self.register_buffer( - "bias", - torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( - 1, 1, max_positions, max_positions - ), - persistent=False, - ) + self._init_bias(config.max_position_embeddings) + self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False) self._init_rope() @@ -114,6 +108,17 @@ def __init__(self, config): self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size) + def _init_bias(self, max_positions, device=None): + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( + 1, 1, max_positions, max_positions + ), + persistent=False, + ) + if device is not None: + self.bias = self.bias.to(device) + def _init_rope(self): if self.config.rope_scaling is None: self.rotary_emb = GPTNeoXRotaryEmbedding( @@ -240,6 +245,9 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): batch_size, num_attention_heads, query_length, attn_head_size = query.size() key_length = key.size(-2) + # dynamically increase the causal mask with the key length, if needed. + if key_length > self.bias.shape[-1]: + self._init_bias(key_length, device=key.device) causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] query = query.view(batch_size * num_attention_heads, query_length, attn_head_size) @@ -297,9 +305,7 @@ def __init__(self, dim, max_position_embeddings, base=10000, device=None): self.register_buffer("inv_freq", inv_freq) # Build here to make `torch.jit.trace` work. 
- self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) + self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device) def _set_cos_sin_cache(self, seq_len, device): self.max_seq_len_cached = seq_len @@ -314,7 +320,7 @@ def _set_cos_sin_cache(self, seq_len, device): def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + self._set_cos_sin_cache(seq_len=seq_len, device=x.device) return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device) diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index 7f04e61ffef7..7df64174a599 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -250,9 +250,7 @@ def __init__(self, dim, max_position_embeddings, base=10000, device=None): self.register_buffer("inv_freq", inv_freq) # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) + self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device) def _set_cos_sin_cache(self, seq_len, device): self.max_seq_len_cached = seq_len @@ -267,7 +265,7 @@ def _set_cos_sin_cache(self, seq_len, device): def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + self._set_cos_sin_cache(seq_len=seq_len, device=x.device) return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device) diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py index 7d3cce542185..423612e1665c 100644 --- a/src/transformers/models/llama/configuration_llama.py +++ b/src/transformers/models/llama/configuration_llama.py @@ -67,8 +67,9 @@ class LlamaConfig(PretrainedConfig): rope_scaling (`Dict`, *optional*): Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling strategies: linear, ntk, and dynamic. Their scaling factor must be an float greater than 1. The expected - format is `{"type": strategy name, "factor": scaling factor}`. See the following thread for more - information on how these scaling strategies behave: + format is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an experimental feature, subject to breaking API changes in future versions. 
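For reviewers trying the new flag, here is a minimal usage sketch based on the docstring and validation above: the `rope_scaling` dict takes `type` and `factor` keys, and `factor` must be a float strictly greater than 1. The checkpoint path below is a placeholder rather than a real model ID, and the snippet assumes the API as of this patch.

```python
from transformers import LlamaConfig, LlamaForCausalLM

# Placeholder path; substitute any Llama checkpoint available locally.
checkpoint = "path/to/llama-checkpoint"

# Linear scaling: position indices are divided by `factor`, so a factor of 4.0
# maps a window roughly 4x the original training length onto the original range.
config = LlamaConfig.from_pretrained(checkpoint, rope_scaling={"type": "linear", "factor": 4.0})

# Dynamic NTK scaling leaves inputs shorter than the original
# `max_position_embeddings` untouched and only rescales the RoPE base beyond it.
# config = LlamaConfig.from_pretrained(checkpoint, rope_scaling={"type": "dynamic", "factor": 4.0})

# `max_position_embeddings` itself stays at its original value, per the docstring.
model = LlamaForCausalLM.from_pretrained(checkpoint, config=config)

# Malformed values are rejected when the config is instantiated directly:
try:
    LlamaConfig(rope_scaling={"type": "linear", "factor": 0.5})
except ValueError as err:
    print(err)  # the factor field must be a float > 1
```

Note that the validation checks for a true float, so `{"type": "linear", "factor": 4}` would be rejected; the tests added below pass `10.0` for the same reason.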
diff --git a/src/transformers/models/open_llama/configuration_open_llama.py b/src/transformers/models/open_llama/configuration_open_llama.py index dd27b56d90a5..6b4bb3dab98f 100644 --- a/src/transformers/models/open_llama/configuration_open_llama.py +++ b/src/transformers/models/open_llama/configuration_open_llama.py @@ -70,10 +70,12 @@ class OpenLlamaConfig(PretrainedConfig): rope_scaling (`Dict`, *optional*): Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling strategies: linear, ntk, and dynamic. Their scaling factor must be an float greater than 1. The expected - format is `{"type": strategy name, "factor": scaling factor}`. See the following thread for more - information on how these scaling strategies behave: + format is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an experimental feature, subject to breaking API changes in future versions. + Example: ```python diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py index ed9b5764a361..3faf23f28df6 100644 --- a/tests/models/gpt_neox/test_modeling_gpt_neox.py +++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py @@ -17,7 +17,9 @@ import unittest -from transformers import AutoTokenizer, GPTNeoXConfig, is_torch_available +from parameterized import parameterized + +from transformers import AutoTokenizer, GPTNeoXConfig, is_torch_available, set_seed from transformers.testing_utils import require_torch, slow, torch_device from ...generation.test_utils import GenerationTesterMixin @@ -49,7 +51,7 @@ def __init__( use_token_type_ids=True, use_labels=True, vocab_size=99, - hidden_size=32, + hidden_size=64, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37, @@ -298,6 +300,37 @@ def test_model_for_token_classification(self): def test_feed_forward_chunking(self): pass + @parameterized.expand([("linear",), ("ntk",), ("dynamic",)]) + def test_model_rope_scaling(self, scaling_type): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + short_input = ids_tensor([1, 10], config.vocab_size) + long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + original_model = GPTNeoXModel(config) + original_model.to(torch_device) + original_model.eval() + original_short_output = original_model(short_input).last_hidden_state + original_long_output = original_model(long_input).last_hidden_state + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + config.rope_scaling = {"type": scaling_type, "factor": 10.0} + scaled_model = GPTNeoXModel(config) + scaled_model.to(torch_device) + scaled_model.eval() + scaled_short_output = scaled_model(short_input).last_hidden_state + scaled_long_output = scaled_model(long_input).last_hidden_state + + # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original + # maximum sequence length, so the outputs for the short input should match. 
+ if scaling_type == "dynamic": + self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) + else: + self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) + + # The output should be different for long inputs + self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) + @require_torch class GPTNeoXLanguageGenerationTest(unittest.TestCase): diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 113f74d30976..7138cefa9a39 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -17,7 +17,9 @@ import unittest -from transformers import LlamaConfig, is_torch_available +from parameterized import parameterized + +from transformers import LlamaConfig, is_torch_available, set_seed from transformers.testing_utils import require_torch, torch_device from ...generation.test_utils import GenerationTesterMixin @@ -332,3 +334,34 @@ def test_llama_sequence_classification_model_for_multi_label(self): @unittest.skip("LLaMA buffers include complex numbers, which breaks this test") def test_save_load_fast_init_from_base(self): pass + + @parameterized.expand([("linear",), ("ntk",), ("dynamic",)]) + def test_model_rope_scaling(self, scaling_type): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + short_input = ids_tensor([1, 10], config.vocab_size) + long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + original_model = LlamaModel(config) + original_model.to(torch_device) + original_model.eval() + original_short_output = original_model(short_input).last_hidden_state + original_long_output = original_model(long_input).last_hidden_state + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + config.rope_scaling = {"type": scaling_type, "factor": 10.0} + scaled_model = LlamaModel(config) + scaled_model.to(torch_device) + scaled_model.eval() + scaled_short_output = scaled_model(short_input).last_hidden_state + scaled_long_output = scaled_model(long_input).last_hidden_state + + # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original + # maximum sequence length, so the outputs for the short input should match. 
+ if scaling_type == "dynamic": + self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) + else: + self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) + + # The output should be different for long inputs + self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) diff --git a/tests/models/open_llama/test_modeling_open_llama.py b/tests/models/open_llama/test_modeling_open_llama.py index 6cdf61a8729d..e7d05166b8b0 100644 --- a/tests/models/open_llama/test_modeling_open_llama.py +++ b/tests/models/open_llama/test_modeling_open_llama.py @@ -17,7 +17,9 @@ import unittest -from transformers import OpenLlamaConfig, is_torch_available +from parameterized import parameterized + +from transformers import OpenLlamaConfig, is_torch_available, set_seed from transformers.testing_utils import require_torch, torch_device from ...generation.test_utils import GenerationTesterMixin @@ -335,3 +337,34 @@ def test_open_llama_sequence_classification_model_for_multi_label(self): @unittest.skip("Open-Llama buffers include complex numbers, which breaks this test") def test_save_load_fast_init_from_base(self): pass + + @parameterized.expand([("linear",), ("ntk",), ("dynamic",)]) + def test_model_rope_scaling(self, scaling_type): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + short_input = ids_tensor([1, 10], config.vocab_size) + long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size) + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + original_model = OpenLlamaModel(config) + original_model.to(torch_device) + original_model.eval() + original_short_output = original_model(short_input).last_hidden_state + original_long_output = original_model(long_input).last_hidden_state + + set_seed(42) # Fixed seed at init time so the two models get the same random weights + config.rope_scaling = {"type": scaling_type, "factor": 10.0} + scaled_model = OpenLlamaModel(config) + scaled_model.to(torch_device) + scaled_model.eval() + scaled_short_output = scaled_model(short_input).last_hidden_state + scaled_long_output = scaled_model(long_input).last_hidden_state + + # Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original + # maximum sequence length, so the outputs for the short input should match. + if scaling_type == "dynamic": + self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) + else: + self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5)) + + # The output should be different for long inputs + self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) From 8c525244e3b4fa891622b00ca530c597eff6a35c Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 12 Jul 2023 17:48:42 +0000 Subject: [PATCH 5/8] GPTNeoX can now handle long inputs, so the pipeline test was wrong --- tests/pipelines/test_pipelines_text_generation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/test_pipelines_text_generation.py b/tests/pipelines/test_pipelines_text_generation.py index b4605164aefd..44a29a673d81 100644 --- a/tests/pipelines/test_pipelines_text_generation.py +++ b/tests/pipelines/test_pipelines_text_generation.py @@ -240,7 +240,7 @@ def run_pipeline_test(self, text_generator, _): # We don't care about infinite range models. # They already work. 
# Skip this test for XGLM, since it uses sinusoidal positional embeddings which are resized on-the-fly. - EXTRA_MODELS_CAN_HANDLE_LONG_INPUTS = ["RwkvForCausalLM", "XGLMForCausalLM"] + EXTRA_MODELS_CAN_HANDLE_LONG_INPUTS = ["RwkvForCausalLM", "XGLMForCausalLM", "GPTNeoXForCausalLM"] if ( tokenizer.model_max_length < 10000 and text_generator.model.__class__.__name__ not in EXTRA_MODELS_CAN_HANDLE_LONG_INPUTS From c7161b45bff55601f3b257d1cf836c2ad23bd7fe Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 13 Jul 2023 12:03:46 +0100 Subject: [PATCH 6/8] Update src/transformers/models/open_llama/configuration_open_llama.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- .../open_llama/configuration_open_llama.py | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/src/transformers/models/open_llama/configuration_open_llama.py b/src/transformers/models/open_llama/configuration_open_llama.py index 6b4bb3dab98f..43fb556e7f89 100644 --- a/src/transformers/models/open_llama/configuration_open_llama.py +++ b/src/transformers/models/open_llama/configuration_open_llama.py @@ -164,17 +164,19 @@ def _rope_scaling_validation(self): """ Validate the `rope_scaling` configuration. """ - if self.rope_scaling is not None: - if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: - raise ValueError( - "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, " - f"got {self.rope_scaling}" - ) - rope_scaling_type = self.rope_scaling.get("type", None) - rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "ntk", "dynamic"]: - raise ValueError( - f"`rope_scaling`'s name field must be one of ['linear', 'ntk', 'dynamic'], got {rope_scaling_type}" - ) - if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: - raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, " + f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "ntk", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s name field must be one of ['linear', 'ntk', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") From 7c02d922a225d7a7ce0e5eeb0fc78bb88a264830 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 13 Jul 2023 11:10:14 +0000 Subject: [PATCH 7/8] remove ntk --- .../models/gpt_neox/configuration_gpt_neox.py | 34 ++++++++++--------- .../models/gpt_neox/modeling_gpt_neox.py | 18 +--------- .../models/llama/configuration_llama.py | 34 ++++++++++--------- .../models/llama/modeling_llama.py | 15 +------- .../open_llama/configuration_open_llama.py | 12 +++---- .../models/open_llama/modeling_open_llama.py | 16 +-------- .../models/gpt_neox/test_modeling_gpt_neox.py | 2 +- tests/models/llama/test_modeling_llama.py | 2 +- .../open_llama/test_modeling_open_llama.py | 2 +- 9 files changed, 48 
insertions(+), 87 deletions(-) diff --git a/src/transformers/models/gpt_neox/configuration_gpt_neox.py b/src/transformers/models/gpt_neox/configuration_gpt_neox.py index f47dee8421b6..24c4b4a3df03 100644 --- a/src/transformers/models/gpt_neox/configuration_gpt_neox.py +++ b/src/transformers/models/gpt_neox/configuration_gpt_neox.py @@ -80,8 +80,8 @@ class GPTNeoXConfig(PretrainedConfig): speedup at large scales (e.g. 20B). rope_scaling (`Dict`, *optional*): Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling - strategies: linear, ntk, and dynamic. Their scaling factor must be an float greater than 1. The expected - format is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format + is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update `max_position_embeddings` to the expected new maximum. See the following thread for more information on how these scaling strategies behave: https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an @@ -158,17 +158,19 @@ def _rope_scaling_validation(self): """ Validate the `rope_scaling` configuration. """ - if self.rope_scaling is not None: - if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: - raise ValueError( - "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, " - f"got {self.rope_scaling}" - ) - rope_scaling_type = self.rope_scaling.get("type", None) - rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "ntk", "dynamic"]: - raise ValueError( - f"`rope_scaling`'s name field must be one of ['linear', 'ntk', 'dynamic'], got {rope_scaling_type}" - ) - if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: - raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, " + f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 45f53c7f8ebc..b1d694e61de1 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -135,13 +135,6 @@ def _init_rope(self): base=self.config.rotary_emb_base, scaling_factor=scaling_factor, ) - elif scaling_type == "ntk": - self.rotary_emb = GPTNeoXNTKScalingRotaryEmbedding( - self.rotary_ndims, - self.config.max_position_embeddings, - base=self.config.rotary_emb_base, - scaling_factor=scaling_factor, - ) elif scaling_type == 
"dynamic": self.rotary_emb = GPTNeoXDynamicNTKScalingRotaryEmbedding( self.rotary_ndims, @@ -346,17 +339,8 @@ def _set_cos_sin_cache(self, seq_len, device): self.sin_cached = emb.sin()[None, None, :, :] -class GPTNeoXNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding): - """GPTNeoXRotaryEmbedding extended with NTK scaling. Credits to the Reddit user /u/bloc97""" - - def __init__(self, dim, max_position_embeddings, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - base = base * scaling_factor ** (dim / (dim - 2)) - super().__init__(dim, max_position_embeddings, base, device) - - class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding): - """GPTNeoXRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit user /u/emozilla""" + """GPTNeoXRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" def __init__(self, dim, max_position_embeddings, base=10000, device=None, scaling_factor=1.0): self.scaling_factor = scaling_factor diff --git a/src/transformers/models/llama/configuration_llama.py b/src/transformers/models/llama/configuration_llama.py index 423612e1665c..d456b79e66fb 100644 --- a/src/transformers/models/llama/configuration_llama.py +++ b/src/transformers/models/llama/configuration_llama.py @@ -66,8 +66,8 @@ class LlamaConfig(PretrainedConfig): Whether to tie weight embeddings rope_scaling (`Dict`, *optional*): Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling - strategies: linear, ntk, and dynamic. Their scaling factor must be an float greater than 1. The expected - format is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format + is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update `max_position_embeddings` to the expected new maximum. See the following thread for more information on how these scaling strategies behave: https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an @@ -134,17 +134,19 @@ def _rope_scaling_validation(self): """ Validate the `rope_scaling` configuration. 
""" - if self.rope_scaling is not None: - if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: - raise ValueError( - "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, " - f"got {self.rope_scaling}" - ) - rope_scaling_type = self.rope_scaling.get("type", None) - rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "ntk", "dynamic"]: - raise ValueError( - f"`rope_scaling`'s name field must be one of ['linear', 'ntk', 'dynamic'], got {rope_scaling_type}" - ) - if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: - raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, " + f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index f4b931834302..199078f29300 100755 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -143,17 +143,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) -class LlamaNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with NTK scaling. Credits to the Reddit user /u/bloc97""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - base = base * scaling_factor ** (dim / (dim - 2)) - super().__init__(dim, max_position_embeddings, base, device) - - class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): - """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit user /u/emozilla""" + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): self.scaling_factor = scaling_factor @@ -245,10 +236,6 @@ def _init_rope(self): self.rotary_emb = LlamaLinearScalingRotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor ) - elif scaling_type == "ntk": - self.rotary_emb = LlamaNTKScalingRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor - ) elif scaling_type == "dynamic": self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor diff --git a/src/transformers/models/open_llama/configuration_open_llama.py b/src/transformers/models/open_llama/configuration_open_llama.py index 43fb556e7f89..d1b8c58a82f8 100644 --- a/src/transformers/models/open_llama/configuration_open_llama.py +++ b/src/transformers/models/open_llama/configuration_open_llama.py @@ -69,8 +69,8 @@ class OpenLlamaConfig(PretrainedConfig): Whether to tie weight embeddings rope_scaling (`Dict`, *optional*): Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports three scaling - strategies: linear, ntk, and dynamic. Their scaling factor must be an float greater than 1. The expected - format is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format + is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update `max_position_embeddings` to the expected new maximum. See the following thread for more information on how these scaling strategies behave: https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an @@ -165,8 +165,8 @@ def _rope_scaling_validation(self): Validate the `rope_scaling` configuration. 
""" if self.rope_scaling is None: - return - + return + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: raise ValueError( "`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, " @@ -174,9 +174,9 @@ def _rope_scaling_validation(self): ) rope_scaling_type = self.rope_scaling.get("type", None) rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_type is None or rope_scaling_type not in ["linear", "ntk", "dynamic"]: + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: raise ValueError( - f"`rope_scaling`'s name field must be one of ['linear', 'ntk', 'dynamic'], got {rope_scaling_type}" + f"`rope_scaling`'s name field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" ) if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") diff --git a/src/transformers/models/open_llama/modeling_open_llama.py b/src/transformers/models/open_llama/modeling_open_llama.py index fabae83ab1a1..137c3e8fe86e 100644 --- a/src/transformers/models/open_llama/modeling_open_llama.py +++ b/src/transformers/models/open_llama/modeling_open_llama.py @@ -155,19 +155,9 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) -# Copied from transformers.models.llama.modeling_llama.LlamaNTKScalingRotaryEmbedding with Llama->OpenLlama -class OpenLlamaNTKScalingRotaryEmbedding(OpenLlamaRotaryEmbedding): - """OpenLlamaRotaryEmbedding extended with NTK scaling. Credits to the Reddit user /u/bloc97""" - - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - base = base * scaling_factor ** (dim / (dim - 2)) - super().__init__(dim, max_position_embeddings, base, device) - - # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->OpenLlama class OpenLlamaDynamicNTKScalingRotaryEmbedding(OpenLlamaRotaryEmbedding): - """OpenLlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit user /u/emozilla""" + """OpenLlamaRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): self.scaling_factor = scaling_factor @@ -265,10 +255,6 @@ def _init_rope(self): self.rotary_emb = OpenLlamaLinearScalingRotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor ) - elif scaling_type == "ntk": - self.rotary_emb = OpenLlamaNTKScalingRotaryEmbedding( - self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor - ) elif scaling_type == "dynamic": self.rotary_emb = OpenLlamaDynamicNTKScalingRotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor diff --git a/tests/models/gpt_neox/test_modeling_gpt_neox.py b/tests/models/gpt_neox/test_modeling_gpt_neox.py index 3faf23f28df6..176970779ed5 100644 --- a/tests/models/gpt_neox/test_modeling_gpt_neox.py +++ b/tests/models/gpt_neox/test_modeling_gpt_neox.py @@ -300,7 +300,7 @@ def test_model_for_token_classification(self): def test_feed_forward_chunking(self): pass - @parameterized.expand([("linear",), ("ntk",), ("dynamic",)]) + @parameterized.expand([("linear",), ("dynamic",)]) def test_model_rope_scaling(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() short_input = ids_tensor([1, 10], config.vocab_size) diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 7138cefa9a39..e8b808461945 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -335,7 +335,7 @@ def test_llama_sequence_classification_model_for_multi_label(self): def test_save_load_fast_init_from_base(self): pass - @parameterized.expand([("linear",), ("ntk",), ("dynamic",)]) + @parameterized.expand([("linear",), ("dynamic",)]) def test_model_rope_scaling(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() short_input = ids_tensor([1, 10], config.vocab_size) diff --git a/tests/models/open_llama/test_modeling_open_llama.py b/tests/models/open_llama/test_modeling_open_llama.py index e7d05166b8b0..687b267b70ca 100644 --- a/tests/models/open_llama/test_modeling_open_llama.py +++ b/tests/models/open_llama/test_modeling_open_llama.py @@ -338,7 +338,7 @@ def test_open_llama_sequence_classification_model_for_multi_label(self): def test_save_load_fast_init_from_base(self): pass - @parameterized.expand([("linear",), ("ntk",), ("dynamic",)]) + @parameterized.expand([("linear",), ("dynamic",)]) def test_model_rope_scaling(self, scaling_type): config, _ = self.model_tester.prepare_config_and_inputs_for_common() short_input = ids_tensor([1, 10], config.vocab_size) From f385927bff8da7591ecdbd995249c05fdc9f673e Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Thu, 13 Jul 2023 11:11:43 +0000 Subject: [PATCH 8/8] remove redundant validation --- .../models/open_llama/configuration_open_llama.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/src/transformers/models/open_llama/configuration_open_llama.py b/src/transformers/models/open_llama/configuration_open_llama.py index d1b8c58a82f8..c0629b31e812 100644 --- a/src/transformers/models/open_llama/configuration_open_llama.py +++ b/src/transformers/models/open_llama/configuration_open_llama.py @@ -136,21 +136,6 @@ def __init__( self.rope_scaling = rope_scaling self._rope_scaling_validation() - # RoPE scaling validation - if self.rope_scaling 
is not None: - if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: - raise ValueError( - f"`rope_scaling` must be a dictionary with with two fields, `name` and `factor`, got {self.rope_scaling}" - ) - rope_scaling_name = self.rope_scaling.get("name", None) - rope_scaling_factor = self.rope_scaling.get("factor", None) - if rope_scaling_name is None or rope_scaling_name not in ["linear", "ntk", "dynamic"]: - raise ValueError( - f"`rope_scaling`'s name field must be one of ['linear', 'ntk', 'dynamic'], got {rope_scaling_name}" - ) - if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: - raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") - super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id,