[OpenVINO] Support Zamba2 by OpenVINO #1354
base: main
Changes from 18 commits
@@ -138,6 +138,7 @@
    QwenModelPatcher,
    SanaTextEncoderModelPatcher,
    XverseModelPatcher,
    Zamba2ModelPatcher,
)
@@ -4278,3 +4279,100 @@ class GPT2OpenVINOConfig(GPT2OnnxConfig):
)
class VisionEncoderDecoderOpenVINOConfig(VisionEncoderDecoderOnnxConfig):
    _MODEL_PATCHER = OVSeq2SeqModelPatcher


class Zamba2DummyInputGenerator(DummyInputGenerator):
    """
    Generates dummy past_key_values inputs for Zamba2 architectures.
    """

    SUPPORTED_INPUT_NAMES = ("position_ids", "cache_position", "past_key_values")

    def __init__(
        self,
        task: str,
        normalized_config,
        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
        **kwargs,
    ):
        config = normalized_config.config
        self.num_key_value_heads = normalized_config.num_key_value_heads
        self.intermediate_size = int(config.mamba_expand * config.hidden_size)
        self.ssm_state_size = config.mamba_d_state
        self.conv_kernel_size = config.mamba_d_conv
        self.n_mamba_heads = config.n_mamba_heads
        self.num_hidden_layers = config.num_hidden_layers
        self.mamba_ngroups = config.mamba_ngroups
        self.mamba_d_state = config.mamba_d_state
        self.batch_size = batch_size
        self.mamba_headdim = config.mamba_headdim
        self.head_dim = config.attention_head_dim
        self.num_attention_heads = config.num_attention_heads
        self.sequence_length = sequence_length

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        if input_name == "past_key_values":
            past_key_values = []
            # generate tuples of (key, value, conv_state, ssm_state)
            for i in range(self.num_hidden_layers):
                kv_shape = (self.batch_size, self.num_attention_heads, 1, self.head_dim)
Suggested change:
-                kv_shape = (self.batch_size, self.num_attention_heads, 1, self.head_dim)
+                kv_shape = (self.batch_size, self.num_attention_heads, self.sequence_length, self.head_dim)
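The diff view is truncated here, so the remainder of `generate` is not shown in this excerpt. Below is a minimal sketch of how the `past_key_values` branch could continue, assuming the conv/SSM state shapes follow `transformers`' `Zamba2HybridDynamicCache` and using the base `DummyInputGenerator.random_float_tensor` helper; the shapes and the return line are assumptions rather than lines from this PR, and the suggested `sequence_length` fix above is folded in:

```python
# Sketch only: possible continuation of the "past_key_values" branch of generate().
conv_shape = (
    self.batch_size,
    self.intermediate_size + 2 * self.mamba_ngroups * self.mamba_d_state,
    self.conv_kernel_size,
)
ssm_shape = (self.batch_size, self.n_mamba_heads, self.mamba_headdim, self.mamba_d_state)
for i in range(self.num_hidden_layers):
    kv_shape = (self.batch_size, self.num_attention_heads, self.sequence_length, self.head_dim)
    # one (key, value, conv_state, ssm_state) group per decoder layer
    past_key_values += [
        self.random_float_tensor(kv_shape, framework=framework, dtype=float_dtype),
        self.random_float_tensor(kv_shape, framework=framework, dtype=float_dtype),
        self.random_float_tensor(conv_shape, framework=framework, dtype=float_dtype),
        self.random_float_tensor(ssm_shape, framework=framework, dtype=float_dtype),
    ]
return past_key_values
```

The `position_ids` and `cache_position` branches are omitted; the flattened layout matches the four tensors per layer that `Zamba2ModelPatcher` below unpacks.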
@@ -6521,3 +6521,89 @@ def __exit__(self, exc_type, exc_value, traceback):
        if is_transformers_version(">=", "4.53"):
            Qwen3MoeSparseMoeBlock.forward = self.original_moe_forward

# This patcher class serves for exporting the Zamba2 model to OpenVINO IR
class Zamba2ModelPatcher(ModelPatcher):
    def __init__(
        self,
        config: "OnnxConfig",
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(config, model, model_kwargs)

        from transformers.models.zamba2.modeling_zamba2 import Zamba2HybridDynamicCache

        class Zamba2HybridDynamicCacheWrap(Zamba2HybridDynamicCache):
            def __init__(self, config, batch_size: int, conv_states, ssm_states, key_cache, value_cache):
                # Call parent constructor with all required arguments
                super().__init__(config=config, batch_size=batch_size)
                self.conv_states = conv_states
                self.ssm_states = ssm_states
                self.key_cache = key_cache
                self.value_cache = value_cache
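                # Descriptive note (not in the PR): the wrapper lets the patcher rebuild a
                # Zamba2HybridDynamicCache from externally supplied tensors, overriding the
                # zero-initialized states the parent constructor allocates.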

        # the patch is needed to include KV-cache, Conv, and SSM states in the inputs and outputs.
        def patched_forward(
            input_ids,
            attention_mask=None,
            position_ids=None,
            # cache_position=None,
            past_key_values=None,
        ):
            num_hidden_layers = self.real_config._config.num_hidden_layers
            use_cache = False
            wrapped_cache_params = None
            if past_key_values is not None:
                use_cache = True
                conv_states = []
                ssm_states = []
                key_cache = []
                value_cache = []
                # inputs passed in an order of (key, value, conv_state, ssm_state)
Review comment: So the length of `past_key_values` is always `4 * num_hidden_layers`, is this correct? If yes, would you mind adding that?
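                # Hypothetical check illustrating the reviewer's request (not in the PR):
                # the flat past_key_values list is expected to hold exactly four tensors
                # (key, value, conv_state, ssm_state) per decoder layer.
                assert len(past_key_values) == 4 * num_hidden_layers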
                for idx in range(num_hidden_layers):
                    batch_size = past_key_values[4 * idx].size(0)
Review comment: Not very important, but this could be moved outside the loop.
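                    # Illustration of the reviewer's note (not in the PR): batch_size could be
                    # read once before the loop, e.g. batch_size = past_key_values[0].size(0)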
                    key_cache.append(past_key_values[4 * idx])
                    value_cache.append(past_key_values[4 * idx + 1])
                    conv_states.append(past_key_values[4 * idx + 2])
                    ssm_states.append(past_key_values[4 * idx + 3])

                wrapped_cache_params = Zamba2HybridDynamicCacheWrap(
                    self.real_config._config, batch_size, conv_states, ssm_states, key_cache, value_cache
                )

            causal_lm_output = self.orig_forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=wrapped_cache_params,
                # cache_position=cache_position,
Review comment (on the commented-out cache_position argument): Why?
                use_cache=use_cache,
            )
            outputs = {"logits": causal_lm_output.logits}

            if use_cache:
                past_key_values = causal_lm_output.past_key_values
                # unwrap the Zamba2HybridDynamicCache object
                present_key_values = []
                # outputs returned in the order (key, value, conv_state, ssm_state)
                for idx in range(num_hidden_layers):
                    present_key_values.append(past_key_values.key_cache[idx])
                    present_key_values.append(past_key_values.value_cache[idx])
                    present_key_values.append(past_key_values.conv_states[idx])
                    present_key_values.append(past_key_values.ssm_states[idx])

                outputs["present_key_values"] = present_key_values

            return outputs

        self.patched_forward = patched_forward

    def __enter__(self):
        super().__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
        super().__exit__(exc_type, exc_value, traceback)
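For completeness, a hedged usage sketch of how a Zamba2 checkpoint would be exported and run through optimum-intel once this patcher is in place; the model id and prompt are illustrative assumptions, not taken from this PR:

```python
# Assumes this PR's Zamba2 export support is installed; model id is an example only.
from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "Zyphra/Zamba2-1.2B"  # hypothetical choice of checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
# export=True converts the PyTorch model to OpenVINO IR via the exporter config above
model = OVModelForCausalLM.from_pretrained(model_id, export=True)

inputs = tokenizer("OpenVINO is", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```

Flattening each layer's hybrid cache into four plain tensors is what lets the exported IR expose the Mamba conv/SSM states alongside the attention KV cache as ordinary model inputs and outputs.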