Commits
33 commits
283403f
[OpenVINO] Support Zamba2 by OpenVINO
rkazants Jun 20, 2025
e6ef129
Merge remote-tracking branch 'upstream/main' into support_zamba2_ov
rkazants Jul 27, 2025
f535012
Apply suggestions from code review
rkazants Jul 27, 2025
112af9a
Apply suggestions from code review
rkazants Jul 27, 2025
708f52c
Apply suggestions from code review
rkazants Jul 27, 2025
f956e00
Apply suggestions from code review
rkazants Jul 27, 2025
ff1dbc6
Apply suggestions from code review
rkazants Jul 27, 2025
32cfb33
Apply suggestions from code review
rkazants Jul 27, 2025
6169f62
Apply suggestions from code review
rkazants Jul 27, 2025
78b21de
Apply suggestions from code review
rkazants Jul 27, 2025
191a3f4
Apply suggestions from code review
rkazants Jul 27, 2025
018d81a
Revert changes in notebooks/openvino/stable_diffusion_hybrid_quantiza…
rkazants Jul 27, 2025
7be4c4b
Add tests
rkazants Jul 28, 2025
c6ef767
Fix formatting
rkazants Jul 28, 2025
3b97ea5
Merge remote-tracking branch 'upstream/main' into support_zamba2_ov
rkazants Jul 31, 2025
ff470f7
Re-implement exporting Zamba2 model
rkazants Jul 31, 2025
c906220
Fix export_cli_int8 test
rkazants Jul 31, 2025
196827e
Merge remote-tracking branch 'upstream/main' into support_zamba2_ov
rkazants Oct 9, 2025
34a4ee3
Apply suggestion from @rkazants
rkazants Oct 9, 2025
f094f78
Apply suggestion from @rkazants
rkazants Oct 9, 2025
4c0ffc5
Apply suggestion from @rkazants
rkazants Oct 9, 2025
06ef4e0
Update optimum/exporters/openvino/model_configs.py
rkazants Oct 9, 2025
56bff2e
Update optimum/exporters/openvino/model_configs.py
rkazants Oct 9, 2025
a4e3bd0
Update tests/openvino/test_exporters_cli.py
rkazants Oct 9, 2025
7db344b
Apply suggestion from @rkazants
rkazants Oct 9, 2025
3aca613
Apply suggestion from @rkazants
rkazants Oct 9, 2025
0825f43
Fix formatting
rkazants Oct 9, 2025
b11d517
rkazants Oct 13, 2025
04d1496
^^X
rkazants Oct 13, 2025
bd427b2
Merge remote-tracking branch 'origin/support_zamba2_ov' into support_…
rkazants Oct 13, 2025
ef52983
Introduce hybrid cache for both mamba and zamba2 models
rkazants Oct 13, 2025
74e4da7
Handle hybrid cache
rkazants Oct 14, 2025
f4712b3
Fix model config to set correct dimension for sequence length
rkazants Oct 16, 2025
1 change: 1 addition & 0 deletions docs/source/openvino/models.mdx
@@ -143,6 +143,7 @@ Here is the list of the supported architectures :
- XLM
- XLM-Roberta
- XVERSE
- Zamba2

## [Diffusers](https://huggingface.co/docs/diffusers/index)
- Stable Diffusion
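For context, once the architecture is registered, a Zamba2 checkpoint should be exportable through the usual optimum-intel flow, either via optimum-cli export openvino --model <model_id> <output_dir> or directly from Python. A minimal sketch, assuming Zyphra/Zamba2-1.2B as an illustrative checkpoint id (any Zamba2 checkpoint supported by transformers should behave the same way):

from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer

model_id = "Zyphra/Zamba2-1.2B"  # assumption: illustrative checkpoint id
tokenizer = AutoTokenizer.from_pretrained(model_id)
# export=True converts the PyTorch checkpoint to OpenVINO IR using the export config added in this PR
model = OVModelForCausalLM.from_pretrained(model_id, export=True)

inputs = tokenizer("OpenVINO is", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(generated[0], skip_special_tokens=True))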
109 changes: 109 additions & 0 deletions optimum/exporters/openvino/model_configs.py
@@ -147,6 +147,7 @@
StatefulSeq2SeqDecoderPatcher,
UpdateCausalMaskModelPatcher,
XverseModelPatcher,
Zamba2ModelPatcher,
)


@@ -4366,3 +4367,111 @@ def patch_model_for_export(
if self._behavior != VLMConfigBehavior.VISION_EMBEDDINGS:
return super().patch_model_for_export(model, model_kwargs)
return Llama4ImageEmbeddingsModelPatcher(self, model, model_kwargs)


class Zamba2DummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
def __init__(
self,
task: str,
normalized_config: NormalizedTextConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
random_batch_size_range: Optional[Tuple[int, int]] = None,
random_sequence_length_range: Optional[Tuple[int, int]] = None,
**kwargs,
):
super().__init__(
task=task,
normalized_config=normalized_config,
batch_size=batch_size,
sequence_length=sequence_length,
random_batch_size_range=random_batch_size_range,
random_sequence_length_range=random_sequence_length_range,
)
config = normalized_config.config
self.num_key_value_heads = normalized_config.num_key_value_heads
self.intermediate_size = int(config.mamba_expand * config.hidden_size)
self.ssm_state_size = config.mamba_d_state
self.conv_kernel_size = config.mamba_d_conv
self.n_mamba_heads = config.n_mamba_heads
self.num_hidden_layers = config.num_hidden_layers
self.mamba_ngroups = config.mamba_ngroups
self.mamba_d_state = config.mamba_d_state
self.batch_size = batch_size
self.mamba_headdim = config.mamba_headdim
self.head_dim = config.attention_head_dim
self.num_attention_heads = config.num_attention_heads

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
kv_states_cache = []
# per layer, append key, value, conv_state, and ssm_state tensors to a flat list
for i in range(self.num_hidden_layers):
kv_shape = (self.batch_size, self.num_attention_heads, 1, self.head_dim)
k = self.random_float_tensor(kv_shape, framework=framework, dtype=float_dtype)
v = self.random_float_tensor(kv_shape, framework=framework, dtype=float_dtype)
kv_states_cache.append(k)
kv_states_cache.append(v)
conv_state_shape = (
self.batch_size,
self.intermediate_size + 2 * self.mamba_ngroups * self.mamba_d_state,
self.conv_kernel_size,
)
conv_state = self.random_float_tensor(conv_state_shape, framework=framework, dtype=float_dtype)
kv_states_cache.append(conv_state)
ssm_state_shape = (self.batch_size, self.n_mamba_heads, self.mamba_headdim, self.ssm_state_size)
ssm_state = self.random_float_tensor(ssm_state_shape, framework=framework, dtype=float_dtype)
kv_states_cache.append(ssm_state)

return kv_states_cache
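
Note that the generator returns a flat list rather than per-layer tuples: entry 4*i is layer i's key, 4*i+1 its value, 4*i+2 its conv state, and 4*i+3 its SSM state. A small sketch of how such a flat list maps back to per-layer state (the helper name is illustrative, not part of the PR):

from typing import List, Tuple

import torch

def group_zamba2_states(flat: List[torch.Tensor]) -> List[Tuple[torch.Tensor, ...]]:
    # Regroup the flat dummy cache into per-layer (key, value, conv_state, ssm_state)
    # tuples, mirroring the unpacking performed later in Zamba2ModelPatcher.
    assert len(flat) % 4 == 0, "expected 4 tensors (key, value, conv, ssm) per layer"
    return [tuple(flat[4 * i : 4 * i + 4]) for i in range(len(flat) // 4)]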


@register_in_tasks_manager("zamba2", *["text-generation", "text-generation-with-past"], library_name="transformers")
class Zamba2OpenVINOConfig(LlamaOpenVINOConfig):
PAD_ATTENTION_MASK_TO_PAST = False
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, Zamba2DummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = Zamba2DummyPastKeyValuesGenerator

def add_past_key_values_states(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
if direction not in ["inputs", "outputs"]:
raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

if direction == "inputs":
decoder_sequence_name = "past_sequence_length"
name = "past_key_values"
else:
decoder_sequence_name = "past_sequence_length + 1"
name = "present"

# per layer, declare key, value, conv_state, and ssm_state entries
for i in range(self._normalized_config.num_layers):
inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch_size", 1: decoder_sequence_name}
inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch_size", 1: decoder_sequence_name}
inputs_or_outputs[f"{name}.{i}.conv_state"] = {0: "batch_size"}
inputs_or_outputs[f"{name}.{i}.ssm_state"] = {0: "batch_size"}

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
common_inputs = {"input_ids": {0: "batch_size", 1: "sequence_length"}}
common_inputs["attention_mask"] = {0: "batch_size", 1: "sequence_length"}
common_inputs["position_ids"] = {0: "batch_size", 1: "sequence_length"}
self.add_past_key_values_states(common_inputs, direction="inputs")
return common_inputs

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
common_outputs = {}
common_outputs["logits"] = {}

# outputs follow the per-layer order (key, value, conv_state, ssm_state)
for idx in range(self._normalized_config.num_layers):
common_outputs["key_cache.present.{}".format(idx)] = {}
common_outputs["value_cache.present.{}".format(idx)] = {}
common_outputs["conv_states.present.{}".format(idx)] = {}
common_outputs["ssm_states.present.{}".format(idx)] = {}

return common_outputs

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
):
return Zamba2ModelPatcher(self, model, model_kwargs)
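
For a hypothetical two-layer configuration, the export would use the following input and output names (a sketch derived from the properties above; the toy layer count is an assumption for illustration):

num_layers = 2  # assumption: toy value for illustration
input_names = ["input_ids", "attention_mask", "position_ids"]
output_names = ["logits"]
for i in range(num_layers):
    input_names += [f"past_key_values.{i}.{s}" for s in ("key", "value", "conv_state", "ssm_state")]
    output_names += [f"{kind}.present.{i}" for kind in ("key_cache", "value_cache", "conv_states", "ssm_states")]
print(input_names)   # input_ids, ..., past_key_values.0.key, past_key_values.0.value, past_key_values.0.conv_state, ...
print(output_names)  # logits, key_cache.present.0, value_cache.present.0, conv_states.present.0, ssm_states.present.0, ...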
80 changes: 80 additions & 0 deletions optimum/exporters/openvino/model_patcher.py
@@ -6456,3 +6456,83 @@ def __exit__(self, exc_type, exc_value, traceback):
if layer.is_moe_layer:
layer.feed_forward.forward = layer.feed_forward._orig_forward
layer.self_attn.forward = layer.self_attn._orig_forward


class Zamba2ModelPatcher(DecoderModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(config, model, model_kwargs)

def __enter__(self):
self._model._orig_forward = self._model.forward

# the patch is needed to expose the KV cache, convolution states, and SSM states as explicit model inputs and outputs
def zamba2_forward(self, input_ids, attention_mask, position_ids, past_key_values):
from transformers.models.zamba2.modeling_zamba2 import Zamba2HybridDynamicCache

class CustomZamba2HybridDynamicCache(Zamba2HybridDynamicCache):
def __init__(self, config, batch_size: int, conv_states, ssm_states, key_cache, value_cache):
"""
Custom Zamba2 cache object that holds conv, SSM, and KV states.

Args:
config (PretrainedConfig): Model configuration object.
batch_size (int): Batch size forwarded to the parent cache constructor.
conv_states (List[torch.Tensor]): Cached convolution states, one per layer.
ssm_states (List[torch.Tensor]): Cached state-space model states, one per layer.
key_cache (List[torch.Tensor]): Cached attention keys, one per layer.
value_cache (List[torch.Tensor]): Cached attention values, one per layer.
"""
# Call parent constructor with all required arguments
super().__init__(config=config, batch_size=batch_size)
self.conv_states = conv_states
self.ssm_states = ssm_states
self.key_cache = key_cache
self.value_cache = value_cache

batch_size = 2  # fallback for tracing; overridden below from the past state tensors

num_hidden_layers = self.config.num_hidden_layers
conv_states = []
ssm_states = []
key_cache = []
value_cache = []
# past states are passed as a flat list in the per-layer order (key, value, conv_state, ssm_state)
for idx in range(num_hidden_layers):
batch_size = past_key_values[4 * idx].size(0)

key_cache.append(past_key_values[4 * idx])
value_cache.append(past_key_values[4 * idx + 1])
conv_states.append(past_key_values[4 * idx + 2])
ssm_states.append(past_key_values[4 * idx + 3])

past_key_values = CustomZamba2HybridDynamicCache(
self.config, batch_size, conv_states, ssm_states, key_cache, value_cache
)

causal_lm_output = self._orig_forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
)
outputs = {"logits": causal_lm_output.logits}

# outputs follow the per-layer order (key, value, conv_state, ssm_state)
num_states = len(past_key_values.conv_states)
for idx in range(num_states):
outputs["key_cache.present.{}".format(idx)] = past_key_values.key_cache[idx]
outputs["value_cache.present.{}".format(idx)] = past_key_values.value_cache[idx]
outputs["conv_states.present.{}".format(idx)] = past_key_values.conv_states[idx]
outputs["ssm_states.present.{}".format(idx)] = past_key_values.ssm_states[idx]

return outputs

self._model.forward = types.MethodType(zamba2_forward, self._model)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model._orig_forward
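
The patcher follows the same pattern as the other patchers in this file: it is entered as a context manager around the traced forward call, so the exported graph sees the flat state layout defined above. A rough sketch, assuming export_config is a Zamba2OpenVINOConfig and dummy_inputs were built from its dummy input generators (both are normally prepared by the export pipeline):

# Illustrative only: inside the context, model.forward takes the flat past_key_values
# list and returns a dict with "logits" plus the per-layer key/value, conv, and SSM states.
with Zamba2ModelPatcher(export_config, model):
    outputs = model(**dummy_inputs)
assert "logits" in outputs and "ssm_states.present.0" in outputs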