Commits (changes from 25 of 33 commits shown)
283403f
[OpenVINO] Support Zamba2 by OpenVINO
rkazants Jun 20, 2025
e6ef129
Merge remote-tracking branch 'upstream/main' into support_zamba2_ov
rkazants Jul 27, 2025
f535012
Apply suggestions from code review
rkazants Jul 27, 2025
112af9a
Apply suggestions from code review
rkazants Jul 27, 2025
708f52c
Apply suggestions from code review
rkazants Jul 27, 2025
f956e00
Apply suggestions from code review
rkazants Jul 27, 2025
ff1dbc6
Apply suggestions from code review
rkazants Jul 27, 2025
32cfb33
Apply suggestions from code review
rkazants Jul 27, 2025
6169f62
Apply suggestions from code review
rkazants Jul 27, 2025
78b21de
Apply suggestions from code review
rkazants Jul 27, 2025
191a3f4
Apply suggestions from code review
rkazants Jul 27, 2025
018d81a
Revert changes in notebooks/openvino/stable_diffusion_hybrid_quantiza…
rkazants Jul 27, 2025
7be4c4b
Add tests
rkazants Jul 28, 2025
c6ef767
Fix formatting
rkazants Jul 28, 2025
3b97ea5
Merge remote-tracking branch 'upstream/main' into support_zamba2_ov
rkazants Jul 31, 2025
ff470f7
Re-implement exporting Zamba2 model
rkazants Jul 31, 2025
c906220
Fix export_cli_int8 test
rkazants Jul 31, 2025
196827e
Merge remote-tracking branch 'upstream/main' into support_zamba2_ov
rkazants Oct 9, 2025
34a4ee3
Apply suggestion from @rkazants
rkazants Oct 9, 2025
f094f78
Apply suggestion from @rkazants
rkazants Oct 9, 2025
4c0ffc5
Apply suggestion from @rkazants
rkazants Oct 9, 2025
06ef4e0
Update optimum/exporters/openvino/model_configs.py
rkazants Oct 9, 2025
56bff2e
Update optimum/exporters/openvino/model_configs.py
rkazants Oct 9, 2025
a4e3bd0
Update tests/openvino/test_exporters_cli.py
rkazants Oct 9, 2025
7db344b
Apply suggestion from @rkazants
rkazants Oct 9, 2025
3aca613
Apply suggestion from @rkazants
rkazants Oct 9, 2025
0825f43
Fix formatting
rkazants Oct 9, 2025
b11d517
rkazants Oct 13, 2025
04d1496
^^X
rkazants Oct 13, 2025
bd427b2
Merge remote-tracking branch 'origin/support_zamba2_ov' into support_…
rkazants Oct 13, 2025
ef52983
Introduce hybrid cache for both mamba and zamba2 models
rkazants Oct 13, 2025
74e4da7
Handle hybrid cache
rkazants Oct 14, 2025
f4712b3
Fix model config to set correct dimension for sequence length
rkazants Oct 16, 2025
1 change: 1 addition & 0 deletions docs/source/openvino/models.mdx
@@ -148,6 +148,7 @@ Here is the list of the supported architectures :
- XLM
- XLM-Roberta
- XVERSE
- Zamba2

## [Diffusers](https://huggingface.co/docs/diffusers/index)
- Stable Diffusion
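
With Zamba2 now in the supported list, a Zamba2 checkpoint can be exported and run through the usual optimum-intel flow. A minimal sketch (illustration only, not part of this PR; the checkpoint id is just an example):

# Roughly equivalent CLI: optimum-cli export openvino --model Zyphra/Zamba2-1.2B zamba2_ov
from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "Zyphra/Zamba2-1.2B"  # example checkpoint, not taken from this PR
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = OVModelForCausalLM.from_pretrained(model_id, export=True)  # converts to OpenVINO IR on the fly

inputs = tokenizer("OpenVINO is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
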
99 changes: 99 additions & 0 deletions optimum/exporters/openvino/model_configs.py
@@ -138,6 +138,7 @@
QwenModelPatcher,
SanaTextEncoderModelPatcher,
XverseModelPatcher,
Zamba2ModelPatcher,
)


@@ -4278,3 +4279,101 @@ class GPT2OpenVINOConfig(GPT2OnnxConfig):
)
class VisionEncoderDecoderOpenVINOConfig(VisionEncoderDecoderOnnxConfig):
_MODEL_PATCHER = OVSeq2SeqModelPatcher


class Zamba2DummyInputGenerator(DummyInputGenerator):
"""
Generates dummy past_key_values inputs for Zamba2 architectures.
"""

SUPPORTED_INPUT_NAMES = ("position_ids", "cache_position", "past_key_values")

def __init__(
self,
task: str,
normalized_config,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
**kwargs,
):
config = normalized_config.config
self.num_key_value_heads = normalized_config.num_key_value_heads
self.intermediate_size = int(config.mamba_expand * config.hidden_size)
self.ssm_state_size = config.mamba_d_state
self.conv_kernel_size = config.mamba_d_conv
self.n_mamba_heads = config.n_mamba_heads
self.num_hidden_layers = config.num_hidden_layers
self.mamba_ngroups = config.mamba_ngroups
self.mamba_d_state = config.mamba_d_state
self.batch_size = batch_size
self.mamba_headdim = config.mamba_headdim
self.head_dim = config.attention_head_dim
self.num_attention_heads = config.num_attention_heads
self.sequence_length = sequence_length

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
if input_name == "past_key_values":
past_key_values = []
# generate tuples of (key, value, conv_state, ssm_state)
for i in range(self.num_hidden_layers):
kv_shape = (self.batch_size, self.num_attention_heads, 1, self.head_dim)
Review comment (Collaborator): why not use the sequence length here?

Suggested change
kv_shape = (self.batch_size, self.num_attention_heads, 1, self.head_dim)
kv_shape = (self.batch_size, self.num_attention_heads, self.sequence_length, self.head_dim)

k = self.random_float_tensor(kv_shape, framework=framework, dtype=float_dtype)
v = self.random_float_tensor(kv_shape, framework=framework, dtype=float_dtype)
past_key_values.append(k)
past_key_values.append(v)
conv_state_shape = (
self.batch_size,
self.intermediate_size + 2 * self.mamba_ngroups * self.mamba_d_state,
self.conv_kernel_size,
)
conv_state = self.random_float_tensor(conv_state_shape, framework=framework, dtype=float_dtype)
past_key_values.append(conv_state)
ssm_state_shape = (self.batch_size, self.n_mamba_heads, self.mamba_headdim, self.ssm_state_size)
ssm_state = self.random_float_tensor(ssm_state_shape, framework=framework, dtype=float_dtype)
past_key_values.append(ssm_state)
return past_key_values

raise ValueError(f"Unsupported input name {input_name}")


@register_in_tasks_manager("zamba2", *["text-generation", "text-generation-with-past"], library_name="transformers")
class Zamba2OpenVINOConfig(LlamaOpenVINOConfig):
PAD_ATTENTION_MASK_TO_PAST = False
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, Zamba2DummyInputGenerator)
DUMMY_PKV_GENERATOR_CLASS = Zamba2DummyInputGenerator
MIN_TRANSFORMERS_VERSION = "4.49.0"

def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
if direction not in ["inputs", "outputs"]:
raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

if direction == "inputs":
decoder_sequence_name = "past_sequence_length"
kv_name = "past_key_values"
else:
decoder_sequence_name = "past_sequence_length + sequence_length"
kv_name = "present"

for i in range(self._normalized_config.num_layers):
inputs_or_outputs[f"{kv_name}.key.{i}"] = {0: "batch_size", 1: decoder_sequence_name}
inputs_or_outputs[f"{kv_name}.value.{i}"] = {0: "batch_size", 1: decoder_sequence_name}
# [batch_size, intermediate_size + 2 * mamba_ngroups * mamba_d_state, conv_kernel_size]
inputs_or_outputs[f"{kv_name}.conv_state.{i}"] = {0: "batch_size"}
# [batch_size, n_mamba_heads, mamba_headdim, ssm_state_size]
inputs_or_outputs[f"{kv_name}.ssm_state.{i}"] = {0: "batch_size"}

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
common_inputs = {
"input_ids": {0: "batch_size", 1: "sequence_length"},
"attention_mask": {0: "batch_size", 1: "sequence_length"},
"position_ids": {0: "batch_size", 1: "sequence_length"},
}
if self.use_past_in_inputs:
self.add_past_key_values(common_inputs, direction="inputs")
return common_inputs

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
):
return Zamba2ModelPatcher(self, model, model_kwargs)
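
For reference, the per-layer dummy cache shapes built by Zamba2DummyInputGenerator can be traced directly from the config attributes it reads. A small standalone sketch (illustration only, the numeric values are made up):

# Attribute names mirror transformers' Zamba2Config; the values are arbitrary.
batch_size = 2
hidden_size, mamba_expand = 64, 2
mamba_ngroups, mamba_d_state, mamba_d_conv = 1, 16, 4
n_mamba_heads, mamba_headdim = 4, 32
num_attention_heads, attention_head_dim = 4, 16

intermediate_size = int(mamba_expand * hidden_size)  # 128
kv_shape = (batch_size, num_attention_heads, 1, attention_head_dim)
conv_state_shape = (batch_size, intermediate_size + 2 * mamba_ngroups * mamba_d_state, mamba_d_conv)
ssm_state_shape = (batch_size, n_mamba_heads, mamba_headdim, mamba_d_state)
# Flat past_key_values layout: (key, value, conv_state, ssm_state) per layer,
# i.e. 4 * num_hidden_layers tensors in total.
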
84 changes: 84 additions & 0 deletions optimum/exporters/openvino/model_patcher.py
@@ -6521,3 +6521,87 @@ def __exit__(self, exc_type, exc_value, traceback):

if is_transformers_version(">=", "4.53"):
Qwen3MoeSparseMoeBlock.forward = self.original_moe_forward


# This patcher class is used to export the Zamba2 model to OpenVINO IR
class Zamba2ModelPatcher(ModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
super().__init__(config, model, model_kwargs)

from transformers.models.zamba2.modeling_zamba2 import Zamba2HybridDynamicCache

class Zamba2HybridDynamicCacheWrap(Zamba2HybridDynamicCache):
def __init__(self, config, batch_size: int, conv_states, ssm_states, key_cache, value_cache):
# Call parent constructor with all required arguments
super().__init__(config=config, batch_size=batch_size)
self.conv_states = conv_states
self.ssm_states = ssm_states
self.key_cache = key_cache
self.value_cache = value_cache

# the patch is needed to include KV-cache, Conv, and SSM states in the inputs and outputs.
def patched_forward(
input_ids,
attention_mask=None,
position_ids=None,
# cache_position=None,
past_key_values=None,
):
num_hidden_layers = self.real_config._config.num_hidden_layers
use_cache = False
wrapped_cache_params = None
if past_key_values is not None:
use_cache = True
conv_states = []
ssm_states = []
key_cache = []
value_cache = []
# inputs are passed flat, 4 entries per layer, in the order (key, value, conv_state, ssm_state)
Review comment (Collaborator): so the length of past_key_values is always 4 * num_hidden_layers, is that correct? If yes, would you mind adding that?

for idx in range(num_hidden_layers):
batch_size = past_key_values[4 * idx].size(0)
Review comment (Collaborator): not very important, but this could be moved outside the loop.

key_cache.append(past_key_values[4 * idx])
value_cache.append(past_key_values[4 * idx + 1])
conv_states.append(past_key_values[4 * idx + 2])
ssm_states.append(past_key_values[4 * idx + 3])

wrapped_cache_params = Zamba2HybridDynamicCacheWrap(
self.real_config._config, batch_size, conv_states, ssm_states, key_cache, value_cache
)

causal_lm_output = self.orig_forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=wrapped_cache_params,
# cache_position=cache_position,
Review comment (Collaborator): why is this commented out?

use_cache=use_cache,
)
outputs = {"logits": causal_lm_output.logits}

if use_cache:
past_key_values = causal_lm_output.past_key_values
# unwrap Zamba2HybridDynamicCache object
present_key_values = []
# outputs are returned flat, 4 entries per layer, in the order (key, value, conv_state, ssm_state)
for idx in range(num_hidden_layers):
present_key_values.append(past_key_values.key_cache[idx])
present_key_values.append(past_key_values.value_cache[idx])
present_key_values.append(past_key_values.conv_states[idx])
present_key_values.append(past_key_values.ssm_states[idx])

outputs["present_key_values"] = present_key_values

return outputs

self.patched_forward = patched_forward

def __enter__(self):
super().__enter__()

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
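
To make the layout discussed in the review comments above explicit, here is a small standalone sketch (illustration only, not part of this PR) of the flatten/unflatten convention the patcher assumes, i.e. exactly four tensors per layer:

def flatten_hybrid_cache(key_cache, value_cache, conv_states, ssm_states):
    # One flat list with four entries per layer: key, value, conv_state, ssm_state.
    flat = []
    for k, v, conv, ssm in zip(key_cache, value_cache, conv_states, ssm_states):
        flat.extend([k, v, conv, ssm])
    return flat


def unflatten_hybrid_cache(flat, num_hidden_layers):
    assert len(flat) == 4 * num_hidden_layers, "expected (key, value, conv_state, ssm_state) per layer"
    # key_cache, value_cache, conv_states, ssm_states
    return flat[0::4], flat[1::4], flat[2::4], flat[3::4]
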
55 changes: 43 additions & 12 deletions optimum/exporters/openvino/stateful.py
@@ -269,23 +269,54 @@ def insert_state_for_nodes(model: ov.Model, nodes):
model.add_sinks([assign])


def patch_stateful_ssm(ov_model: ov.Model):
cache_input_names = [
key_name for key in ov_model.inputs for key_name in key.get_names() if "cache_params.past" in key_name
]
cache_output_names = [
key_name for key in ov_model.outputs for key_name in key.get_names() if "cache_params.present" in key_name
]
def patch_stateful_ssm(config: PretrainedConfig, ov_model: ov.Model):
from openvino._offline_transformations import apply_make_stateful_transformation

if not cache_input_names or not cache_output_names:
return
def get_kv_ssm_tensor_names(ssm_prefix_names: list, kv_prefix_names: list, ov_tensors):
# return tensor names of model inputs/outputs tensors with KV and SSM states
kv_names = []
ssm_names = []
other_names = []
for ov_tensor in ov_tensors:
ov_tensor_names = ov_tensor.get_names()
is_kv_or_ssm = False
for ov_tensor_name in ov_tensor_names:
if any(prefix in ov_tensor_name for prefix in ssm_prefix_names):
ssm_names.append(ov_tensor_name)
is_kv_or_ssm = True
break
elif any(prefix in ov_tensor_name for prefix in kv_prefix_names):
kv_names.append(ov_tensor_name)
is_kv_or_ssm = True
break
if not is_kv_or_ssm:
other_names.append(ov_tensor_name)
return kv_names, ssm_names, other_names

ssm_prefix_input_names = ["cache_params.past", "past_key_values.conv_state", "past_key_values.ssm_state"]
kv_prefix_input_names = ["past_key_values.key", "past_key_values.value"]
kv_input_names, ssm_input_names, other_input_names = get_kv_ssm_tensor_names(
ssm_prefix_input_names, kv_prefix_input_names, ov_model.inputs
)
not_kv_inputs = ssm_input_names + other_input_names

ssm_prefix_output_names = ["cache_params.present", "present.conv_state", "present.ssm_state"]
kv_prefix_output_names = ["present.key", "present.value"]
kv_output_names, ssm_output_names, _ = get_kv_ssm_tensor_names(
ssm_prefix_output_names, kv_prefix_output_names, ov_model.outputs
)

batch_dim = 0

from openvino._offline_transformations import apply_make_stateful_transformation
# hybrid models can contain transformer blocks as well
# so KV tensors must be handled properly
if kv_input_names is not None and len(kv_input_names) > 0:
fuse_cache_reorder(ov_model, not_kv_inputs, kv_input_names, batch_dim)
num_attention_heads = config.num_attention_heads
make_stateful(ov_model, not_kv_inputs, kv_input_names, kv_output_names, batch_dim, num_attention_heads, None)

input_output_map = {}
for cache_name_pair in zip(cache_input_names, cache_output_names):
for cache_name_pair in zip(ssm_input_names, ssm_output_names):
input_output_map[cache_name_pair[0]] = cache_name_pair[1]

apply_make_stateful_transformation(ov_model, input_output_map)
@@ -296,7 +327,7 @@ def patch_stateful(config: PretrainedConfig, ov_model: ov.Model):
if config.is_encoder_decoder and model_has_input_output_name(ov_model, "encoder_hidden_states"):
return patch_stateful_encoder_decoder(config, ov_model)
if config.model_type in SSM_MODELS:
return patch_stateful_ssm(ov_model)
return patch_stateful_ssm(config, ov_model)
return patch_stateful_decoder(config, ov_model)


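As an illustration (not part of this PR), for a hypothetical 2-layer hybrid model the tensor-name grouping and the state map handed to apply_make_stateful_transformation would look roughly as follows; the names follow the past_key_values.* / present.* scheme defined in Zamba2OpenVINOConfig.add_past_key_values:

kv_input_names = [
    "past_key_values.key.0", "past_key_values.value.0",
    "past_key_values.key.1", "past_key_values.value.1",
]
ssm_input_names = [
    "past_key_values.conv_state.0", "past_key_values.ssm_state.0",
    "past_key_values.conv_state.1", "past_key_values.ssm_state.1",
]
ssm_output_names = [
    "present.conv_state.0", "present.ssm_state.0",
    "present.conv_state.1", "present.ssm_state.1",
]
# KV tensors go through fuse_cache_reorder/make_stateful as for a regular decoder,
# while conv/ssm states are paired input -> output and turned into states directly:
input_output_map = dict(zip(ssm_input_names, ssm_output_names))
# {"past_key_values.conv_state.0": "present.conv_state.0", ...}
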
2 changes: 1 addition & 1 deletion optimum/exporters/openvino/utils.py
@@ -237,7 +237,7 @@ def get_submodels(model):
"minicpmo",
]

SSM_MODELS = ["mamba", "falcon_mamba"]
SSM_MODELS = ["mamba", "falcon_mamba", "zamba2"]


def save_config(config, save_dir):
3 changes: 3 additions & 0 deletions tests/openvino/test_export.py
@@ -87,6 +87,9 @@ class ExportModelTest(unittest.TestCase):
"ltx-video": OVLTXPipeline,
}

if is_transformers_version(">=", "4.48"):
SUPPORTED_ARCHITECTURES.update({"zamba2": OVModelForCausalLM})

EXPECTED_DIFFUSERS_SCALE_FACTORS = {
"stable-diffusion-xl": {"vae_encoder": "128.0", "vae_decoder": "128.0"},
"stable-diffusion-3": {"text_encoder_3": "8.0"},
7 changes: 7 additions & 0 deletions tests/openvino/test_exporters_cli.py
@@ -105,6 +105,12 @@ class OVCLIExportTestCase(unittest.TestCase):
("zero-shot-image-classification", "clip"),
]

if is_transformers_version(">=", "4.49"):
SUPPORTED_ARCHITECTURES.extend(
[
("text-generation-with-past", "zamba2"),
]
)
EXPECTED_NUMBER_OF_TOKENIZER_MODELS = {
"gpt2": 2 if is_tokenizers_version("<", "0.20") or is_openvino_version(">=", "2024.5") else 0,
"t5": 0 if is_openvino_version("<", "2025.1") else 2, # 2025.1 brings support for unigram tokenizers
@@ -129,6 +135,7 @@ class OVCLIExportTestCase(unittest.TestCase):
"mamba": 2,
"falcon-mamba": 2,
"qwen3": 2,
"zamba2": 2,
}

TOKENIZER_CHAT_TEMPLATE_TESTS_MODELS = {
2 changes: 2 additions & 0 deletions tests/openvino/utils_tests.py
@@ -200,6 +200,7 @@
"sana": "katuni4ka/tiny-random-sana",
"sana-sprint": "katuni4ka/tiny-random-sana-sprint",
"ltx-video": "katuni4ka/tiny-random-ltx-video",
"zamba2": "rkazants/tiny-zamba2",
}


@@ -334,6 +335,7 @@
"vision_embeddings_model": 8,
"resampler_model": 6,
},
"zamba2": {"model": 392},
}

TEST_IMAGE_URL = "http://images.cocodataset.org/val2017/000000039769.jpg"