Transformer 4.45.0 support #35

Open · wants to merge 4 commits into base: main
3 changes: 3 additions & 0 deletions QQQ/gptq/models/qwen2.py
@@ -328,6 +328,9 @@ def __init__(self, config: Qwen2Config, quant_config: dict):
self._attn_implementation = config._attn_implementation
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

# added for transformer>=4.45.0
self.rotary_emb = Qwen2RotaryEmbedding(config=config)

self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
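For reference, a minimal sketch (not part of this PR) of how the config-driven rotary embedding is used under transformers >= 4.45: `Qwen2RotaryEmbedding` is built once from the model config, as the `__init__` above now does, and, given hidden states and position ids, returns the `(cos, sin)` pair shared by the decoder layers. The config sizes below are made-up toy values.

# Illustrative sketch, assuming the upstream transformers >= 4.45 Qwen2 API;
# the config sizes are toy assumptions, not values from this repo.
import torch
from transformers import Qwen2Config
from transformers.models.qwen2.modeling_qwen2 import Qwen2RotaryEmbedding

config = Qwen2Config(hidden_size=896, num_attention_heads=14)
rotary_emb = Qwen2RotaryEmbedding(config=config)  # config-based constructor, as registered above

hidden_states = torch.randn(1, 8, config.hidden_size)  # (batch, seq_len, hidden_size)
position_ids = torch.arange(8).unsqueeze(0)            # (batch, seq_len)

# Returns the cos/sin tables for these positions, roughly (batch, seq_len, head_dim),
# instead of reading per-layer cached buffers.
cos, sin = rotary_emb(hidden_states, position_ids)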
5 changes: 1 addition & 4 deletions QQQ/smooth/migration/migration_qwen2.py
@@ -198,10 +198,7 @@ def qkv_function(self, input, weight, bias=None):
.view(B, N, self.extra_dict["num_key_value_heads"], head_dim)
.transpose(1, 2)
)
-cos, sin = (
-    self.extra_dict["cos_cached"],
-    self.extra_dict["sin_cached"],
-)
+cos, sin = self.extra_dict['rotary_emb'](v, self.extra_dict["position_ids"])
q, k = qwen2.apply_rotary_pos_emb(
q, k, cos, sin, self.extra_dict["position_ids"]
)
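The cached `cos_cached`/`sin_cached` buffers this hook used to read are deprecated in recent transformers releases, so the hook now receives the rotary module itself through `extra_dict` and computes `cos`/`sin` on the fly. Below is a self-contained sketch of that calling pattern, again assuming the upstream >= 4.45 API; the tensor shapes and config values are assumptions, and `extra_dict` is only a stand-in for the dict built by the attention wrapper later in this PR.

# Self-contained sketch of the new call pattern; not this repo's code.
import torch
from transformers import Qwen2Config
from transformers.models.qwen2.modeling_qwen2 import (
    Qwen2RotaryEmbedding,
    apply_rotary_pos_emb,
)

config = Qwen2Config(hidden_size=896, num_attention_heads=14)  # toy sizes
head_dim = config.hidden_size // config.num_attention_heads
B, N = 1, 8

q = torch.randn(B, config.num_attention_heads, N, head_dim)
k = torch.randn(B, config.num_attention_heads, N, head_dim)
v = torch.randn(B, config.num_attention_heads, N, head_dim)

extra_dict = {
    "rotary_emb": Qwen2RotaryEmbedding(config=config),
    "position_ids": torch.arange(N).unsqueeze(0),
}

# transformers >= 4.45: cos/sin come from the shared module instead of cached buffers,
# and apply_rotary_pos_emb's position_ids argument is deprecated and unused upstream
# (it may still be passed, as the unchanged call above does, but it is ignored).
cos, sin = extra_dict["rotary_emb"](v, extra_dict["position_ids"])
q, k = apply_rotary_pos_emb(q, k, cos, sin)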
33 changes: 28 additions & 5 deletions QQQ/smooth/models/qwen2.py
@@ -1,6 +1,7 @@
""" PyTorch QuantizedLLaMA model."""
import warnings
import logging
from types import SimpleNamespace
from typing import List, Optional, Tuple, Union

import torch
@@ -40,6 +41,14 @@ def __init__(
QuantizedModule.__init__(self, backend=backend)
self.w_qconfig = w_qconfig
self.a_qconfig = a_qconfig

+if hasattr(org_module, 'config'):
+    self.config = org_module.config
+else:
+    self.config = SimpleNamespace()
+    self.config.hidden_size = org_module.hidden_size
+    self.config.intermediate_size = org_module.intermediate_size

-self.config = org_module.config
self.qinput = qinput
self.hidden_size = org_module.hidden_size
@@ -128,7 +137,8 @@ def __init__(
QuantizedModule.__init__(self, backend=backend)
self.w_qconfig = w_qconfig
self.a_qconfig = a_qconfig
-self.config = org_module.config
+if hasattr(org_module, 'config'):
+    self.config = org_module.config
self.qinput = qinput
self.layer_idx = org_module.layer_idx

@@ -178,6 +188,8 @@ def forward(
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
assert not output_attentions
@@ -209,11 +221,12 @@ def forward(
"num_heads": self.num_heads,
"num_key_value_heads": self.num_key_value_heads,
"num_key_value_groups": self.num_key_value_groups,
"cos_cached": self.rotary_emb.cos_cached,
"sin_cached": self.rotary_emb.sin_cached,
"rotary_emb": self.rotary_emb,
"head_dim": self.head_dim,
"position_ids": position_ids,
"attention_mask": attention_mask,
"cache_position": cache_position,
"position_embeddings": position_embeddings,
"observation_mask": observation_mask,
}
# update scale
@@ -249,10 +262,10 @@ def forward(
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+cos, sin = self.rotary_emb(value_states, position_ids)

query_states, key_states = apply_rotary_pos_emb(
-query_states, key_states, cos, sin, position_ids
+query_states, key_states, cos, sin
)

if past_key_value is not None:
@@ -355,6 +368,8 @@ def forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
@@ -432,6 +447,7 @@ def __init__(
self._attn_implementation = org_module._attn_implementation
# NOTE(HandH1998): Qwen2 fp16 is abnormal for `eager` attention, here we only support `sdpa`
assert self._attn_implementation == "sdpa"
self.rotary_emb = org_module.rotary_emb
self.norm = org_module.norm
self.gradient_checkpointing = False

@@ -447,6 +463,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
observation_mask: Optional[torch.Tensor] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions
@@ -543,6 +560,8 @@ def forward(

hidden_states = inputs_embeds

position_embeddings = self.rotary_emb(hidden_states, position_ids)

# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
@@ -559,6 +578,8 @@
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
observation_mask=observation_mask,
)

@@ -640,6 +661,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@@ -698,6 +720,7 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
observation_mask=observation_mask,
cache_position=cache_position,
)

hidden_states = outputs[0]
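At the model level the flow now mirrors upstream 4.45: the rotary embedding is taken from the original module, `(cos, sin)` is computed once per forward pass, and the pair is handed to every decoder layer as `position_embeddings` together with `cache_position`. A minimal sketch of that control flow follows; `DummyLayer` and the toy sizes are assumptions standing in for the quantized decoder layers in this file, not their actual implementation.

# Sketch of the "compute rotary once, pass it to each layer" flow adopted above.
# DummyLayer is a placeholder, not this repo's quantized decoder layer.
import torch
from transformers import Qwen2Config
from transformers.models.qwen2.modeling_qwen2 import Qwen2RotaryEmbedding

class DummyLayer(torch.nn.Module):
    def forward(self, hidden_states, position_embeddings=None, cache_position=None, **kwargs):
        cos, sin = position_embeddings  # precomputed by the caller, no per-layer rotary module
        return hidden_states            # a real layer would run attention with cos/sin here

config = Qwen2Config(hidden_size=896, num_attention_heads=14, num_hidden_layers=2)  # toy sizes
rotary_emb = Qwen2RotaryEmbedding(config=config)
layers = torch.nn.ModuleList([DummyLayer() for _ in range(config.num_hidden_layers)])

inputs_embeds = torch.randn(1, 8, config.hidden_size)
cache_position = torch.arange(8)
position_ids = cache_position.unsqueeze(0)

hidden_states = inputs_embeds
position_embeddings = rotary_emb(hidden_states, position_ids)  # computed once per forward
for layer in layers:
    hidden_states = layer(
        hidden_states,
        position_embeddings=position_embeddings,
        cache_position=cache_position,
    )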
4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,6 +1,6 @@
-transformers==4.38.2
+transformers==4.45.0
datasets==2.16.1
easydict
accelerate
zstandard
lm_eval==0.4.2
lm_eval==0.4.2