101 changes: 101 additions & 0 deletions optimum/exporters/openvino/model_configs.py
@@ -176,6 +176,8 @@
MiniCPMVImageEmbeddingsModelPatcher,
MiniCPMVResamplerModelPatcher,
MistralModelPatcher,
Mistral3ImageEmbeddingModelPatcher,
Mistral3LanguageModelPatcher,
MixtralModelPatcher,
MPTModelPatcher,
OVDecoderModelPatcher,
@@ -287,6 +289,10 @@ def init_model_configs():
"transformers",
"AutoModelForImageTextToText",
)
TasksManager._CUSTOM_CLASSES[("pt", "mistral3", "image-text-to-text")] = (
"transformers",
"Mistral3ForConditionalGeneration",
)

if is_diffusers_available() and "fill" not in TasksManager._DIFFUSERS_TASKS_TO_MODEL_LOADERS:
TasksManager._DIFFUSERS_TASKS_TO_MODEL_LOADERS["fill"] = "FluxFillPipeline"
@@ -4171,6 +4177,101 @@ def with_behavior(
return super().with_behavior(behavior)


@register_in_tasks_manager("mistral3", *["image-text-to-text"], library_name="transformers")
class Mistral3OpenVINOConfig(BaseVLMOpenVINOConfig):
MIN_TRANSFORMERS_VERSION = "5.4.0"

Comment on lines +4180 to +4183 (Copilot AI, Apr 20, 2026):

This PR introduces a new model type (mistral3) with custom export/inference behavior, but there are no corresponding OpenVINO tests added. Given the existing coverage for other VLMs in tests/openvino/*, please add at least a smoke test that exports and runs a short generate pass for mistral3 (or a tiny/random checkpoint) to prevent regressions in the patchers and behavior routing.
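A minimal sketch of such a smoke test, following the existing VLM test pattern; the checkpoint id is hypothetical and would need to point at a real tiny-random mistral3 model:

    # Hedged sketch of a mistral3 export + generate smoke test; the checkpoint id is
    # hypothetical and would need to point at a real tiny-random mistral3 model.
    from PIL import Image
    from transformers import AutoProcessor

    from optimum.intel import OVModelForVisualCausalLM

    TINY_MODEL_ID = "hf-internal-testing/tiny-random-mistral3"  # hypothetical


    def test_mistral3_export_and_generate_smoke():
        model = OVModelForVisualCausalLM.from_pretrained(TINY_MODEL_ID, export=True)
        processor = AutoProcessor.from_pretrained(TINY_MODEL_ID)
        image = Image.new("RGB", (336, 336), color="white")
        inputs = model.preprocess_inputs(text="Describe the image.", image=image, processor=processor)
        outputs = model.generate(**inputs, max_new_tokens=5)
        # The generated sequence must extend past the prompt tokens.
        assert outputs.shape[1] > inputs["input_ids"].shape[1]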
def __init__(
self,
config: "PretrainedConfig",
task: str = "feature-extraction",
int_dtype: str = "int64",
float_dtype: str = "fp32",
behavior: VLMConfigBehavior = VLMConfigBehavior.VISION_EMBEDDINGS,
preprocessors: Optional[List[Any]] = None,
**kwargs,
):
super().__init__(
config=config,
task=task,
int_dtype=int_dtype,
float_dtype=float_dtype,
preprocessors=preprocessors,
)
self._behavior = behavior
self._orig_config = config
if self._behavior == VLMConfigBehavior.VISION_EMBEDDINGS and hasattr(config, "vision_config"):
self._config = config.vision_config
self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)

def with_behavior(
self,
behavior: Union[str, VLMConfigBehavior],
):
if isinstance(behavior, str) and not isinstance(behavior, VLMConfigBehavior):
behavior = VLMConfigBehavior(behavior)

if behavior == VLMConfigBehavior.TEXT_EMBEDDINGS:
model_type = self._orig_config.text_config.model_type
if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
model_type = "mistral"
return get_vlm_text_embeddings_config(
model_type, self._orig_config.text_config, self.int_dtype, self.float_dtype
)

if behavior == VLMConfigBehavior.LANGUAGE:
model_type = self._orig_config.text_config.model_type
if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
model_type = "mistral"
return get_vlm_text_generation_config(
model_type,
self._orig_config.text_config,
self.int_dtype,
self.float_dtype,
model_patcher=Mistral3LanguageModelPatcher,
)

if behavior == VLMConfigBehavior.VISION_EMBEDDINGS:
return self.__class__(
self._orig_config,
task=self.task,
int_dtype=self.int_dtype,
float_dtype=self.float_dtype,
behavior=behavior,
preprocessors=self._preprocessors,
)

def get_model_for_behavior(self, model, behavior: Union[str, VLMConfigBehavior]):
if isinstance(behavior, str) and not isinstance(behavior, VLMConfigBehavior):
behavior = VLMConfigBehavior(behavior)

if behavior == VLMConfigBehavior.LANGUAGE:
return model # full model needed for tied lm_head

if behavior == VLMConfigBehavior.VISION_EMBEDDINGS:
return model # full model, patcher replaces forward

if behavior == VLMConfigBehavior.TEXT_EMBEDDINGS:
text_embedding = model.get_input_embeddings()
text_embedding.config = model.model.language_model.config
return text_embedding

def patch_model_for_export(self, model, model_kwargs=None):
model_kwargs = model_kwargs or {}
if self._behavior == VLMConfigBehavior.VISION_EMBEDDINGS:
return Mistral3ImageEmbeddingModelPatcher(self, model, model_kwargs)
return super().patch_model_for_export(model, model_kwargs)

@property
def outputs(self):
if self._behavior == VLMConfigBehavior.VISION_EMBEDDINGS:
return {"last_hidden_state": {0: "num_patches"}}
return super().outputs

def generate_dummy_inputs(self, framework="pt", **kwargs):
if self._behavior == VLMConfigBehavior.VISION_EMBEDDINGS:
kwargs["batch_size"] = 1
return super().generate_dummy_inputs(framework, **kwargs)

class DummyVisionPositionIdsInputGenerator(DummyVisionInputGenerator):
SUPPORTED_INPUT_NAMES = ("patch_attention_mask", "patch_position_ids")

166 changes: 166 additions & 0 deletions optimum/exporters/openvino/model_patcher.py
@@ -227,6 +227,15 @@ def patch_cos_sin_cached_fp32(model):
def eager_mask_without_vmap(*args, **kwargs) -> Optional[torch.Tensor]:
kwargs.pop("allow_is_causal_skip", None)
dtype = kwargs.get("dtype", torch.float32)
# Handle transformers >= 5.4 API change: q_length kwarg instead of cache_position positional arg
if "q_length" in kwargs and "cache_position" not in kwargs:
q_length = kwargs.pop("q_length")
q_offset = kwargs.pop("q_offset", 0)
device = kwargs.get("device", "cpu")
kwargs["cache_position"] = torch.arange(q_offset, q_offset + q_length, device=device)
# Remove kwargs not accepted by sdpa_mask_without_vmap
for key in ["allow_is_bidirectional_skip", "allow_torch_fix", "use_vmap", "config", "dtype"]:
kwargs.pop(key, None)
mask = sdpa_mask_without_vmap(*args, allow_is_causal_skip=False, **kwargs)
# we use torch.finfo(torch.float16).min instead of torch.finfo(dtype).min to avoid an overflow, but we are not
# sure this is the right way to handle this; we are basically pretending that -65,504 is -inf
@@ -8319,3 +8328,160 @@ def __exit__(self, exc_type, exc_value, traceback):
sparse_moe_block = decoder_layer.mlp
decoder_layer.mlp.forward = decoder_layer.mlp._orig_forward
del sparse_moe_block.down_projs, sparse_moe_block.gate_projs, sparse_moe_block.up_projs


def _mistral3_vision_embed_forward(self, pixel_values):
"""
Inline vision pipeline (vision_tower + multi_modal_projector) for Mistral3.
All dimensions are derived from tensor .shape to stay dynamic during OpenVINO tracing.
"""
vision_tower = self.model.vision_tower
projector = self.model.multi_modal_projector
config = self.config

# Step 1: Patch convolution
target_dtype = vision_tower.patch_conv.weight.dtype
patch_embeds = vision_tower.patch_conv(pixel_values.to(dtype=target_dtype))
# patch_embeds: (batch, hidden, h_patches, w_patches)
h_patches = patch_embeds.shape[2]
w_patches = patch_embeds.shape[3]
d = patch_embeds.shape[1]

# Step 2: Flatten and normalize (single image, batch=1)
patch_embeds = patch_embeds[0].flatten(1).T.unsqueeze(0) # (1, h*w, d)
patch_embeds = vision_tower.ln_pre(patch_embeds)
Comment on lines +8350 to +8352 (Copilot AI, Apr 20, 2026):

_mistral3_vision_embed_forward hard-codes batch size 1 by indexing patch_embeds[0] (and later selected.squeeze(0)), which will silently drop all but the first image if pixel_values has batch > 1. Either make the implementation batch-safe (operate on the full batch) or add an explicit check that pixel_values.shape[0] == 1 and raise a clear error so callers don't get incorrect results.
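A minimal guard along those lines (a sketch, not part of this diff), using the names from the function above:

    # Sketch (not part of this diff): fail loudly instead of silently dropping
    # all images after the first when pixel_values has batch > 1.
    if pixel_values.shape[0] != 1:
        raise ValueError(
            f"_mistral3_vision_embed_forward currently supports a single image per call "
            f"(batch size 1), got batch size {pixel_values.shape[0]}"
        )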

# Step 3: Position embeddings - derive from tensor shapes
max_width = config.vision_config.image_size // config.vision_config.patch_size
h_idx = torch.arange(h_patches, device=pixel_values.device)
w_idx = torch.arange(w_patches, device=pixel_values.device)
mesh_h, mesh_w = torch.meshgrid(h_idx, w_idx, indexing="ij")
position_ids = (mesh_h.reshape(-1) * max_width + mesh_w.reshape(-1))

position_embeddings = vision_tower.patch_positional_embedding(patch_embeds, position_ids)

# Step 4: Build block attention mask for non-flash attention
seq_len = patch_embeds.shape[1]
causal_mask = torch.zeros((seq_len, seq_len), dtype=patch_embeds.dtype, device=patch_embeds.device)
attention_mask = causal_mask[None, None, :, :].expand(1, 1, -1, -1)

# Step 5: Transformer layers
transformer_out = vision_tower.transformer(
patch_embeds,
attention_mask=attention_mask,
position_embeddings=position_embeddings,
)

# Step 6: Select feature layer
selected = transformer_out.last_hidden_state # (1, num_patches, hidden)

# Step 7: Apply projector (norm + patch_merger + MLP)
image_features = projector.norm(selected.squeeze(0)) # (num_patches, hidden)

# Patch merger - unfold + merge
spatial_merge = config.spatial_merge_size
patch_size = config.vision_config.patch_size
Copilot AI, Apr 20, 2026:

patch_size = config.vision_config.patch_size is assigned but never used. Please remove it to avoid dead code and keep this patcher easier to maintain.

Suggested change:
- patch_size = config.vision_config.patch_size
image_grid = image_features.view(h_patches, w_patches, d).permute(2, 0, 1).unsqueeze(0)
grid = torch.nn.functional.unfold(image_grid, kernel_size=spatial_merge, stride=spatial_merge)
grid = grid.view(d * spatial_merge ** 2, -1).t()
image_features = projector.patch_merger.merging_layer(grid)

# MLP projection
image_features = projector.linear_1(image_features)
image_features = projector.act(image_features)
image_features = projector.linear_2(image_features)

return image_features


class Mistral3ImageEmbeddingModelPatcher(ModelPatcher):
def __init__(self, config, model, model_kwargs):
model.__orig_forward = model.forward
model.forward = types.MethodType(_mistral3_vision_embed_forward, model)
super().__init__(config, model, model_kwargs)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
self._model.forward = self._model.__orig_forward


class Mistral3LanguageModelPatcher(OVDecoderModelPatcher):
"""Patcher for Mistral3 language model: fixes sliding_window=None and injects cache_position."""

def __init__(
self,
config: "OnnxConfig",
model: "PreTrainedModel",
model_kwargs: Optional[Dict[str, Any]] = None,
):
# Override forward to inject cache_position (required by attention mask functions)
def forward(self, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache=True):
# Compute cache_position from past_key_values shape
if isinstance(past_key_values, (tuple, list)) and len(past_key_values) > 0:
if isinstance(past_key_values[0], (tuple, list)):
past_seen_tokens = past_key_values[0][0].shape[-2]
else:
past_seen_tokens = past_key_values[0].shape[-2]
else:
past_seen_tokens = 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)

result = self.__orig_forward(
input_ids=None,
attention_mask=attention_mask,
position_ids=position_ids,
cache_position=cache_position,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
)
return result

model.__orig_forward = model.forward
model.forward = types.MethodType(forward, model)

super().__init__(config, model, model_kwargs)

def _get_language_model(self):
"""Find the language model within the Mistral3 composite model."""
if hasattr(self._model, "model") and hasattr(self._model.model, "language_model"):
return self._model.model.language_model
if hasattr(self._model, "language_model"):
return self._model.language_model
return self._model

def __enter__(self):
# Fix sliding_window=None in the language model config
lang_model = self._get_language_model()
cfg = lang_model.config if hasattr(lang_model, "config") else self._model.config
self._orig_sliding_window = getattr(cfg, "sliding_window", None)
if self._orig_sliding_window is None:
cfg.sliding_window = getattr(cfg, "max_position_embeddings", 262144)

if hasattr(lang_model, "layers"):
for layer in lang_model.layers:
if hasattr(layer.self_attn, "sliding_window"):
layer.self_attn._orig_sliding_window = layer.self_attn.sliding_window
if layer.self_attn.sliding_window is None:
layer.self_attn.sliding_window = cfg.sliding_window

super().__enter__()
return self

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)

# Restore forward
self._model.forward = self._model.__orig_forward

# Restore sliding_window
lang_model = self._get_language_model()
cfg = lang_model.config if hasattr(lang_model, "config") else self._model.config
cfg.sliding_window = self._orig_sliding_window
if hasattr(lang_model, "layers"):
for layer in lang_model.layers:
if hasattr(layer.self_attn, "_orig_sliding_window"):
layer.self_attn.sliding_window = layer.self_attn._orig_sliding_window
del layer.self_attn._orig_sliding_window
1 change: 1 addition & 0 deletions optimum/exporters/openvino/utils.py
@@ -303,6 +303,7 @@ def get_submodels(model):
"phi4_multimodal",
"llama4",
"minicpmo",
"mistral3",
]

SSM_MODELS = ["mamba", "falcon_mamba", "zamba2", "lfm2", "granitemoehybrid", "qwen3_next"]
56 changes: 56 additions & 0 deletions optimum/intel/openvino/modeling_visual_language.py
@@ -4802,6 +4802,61 @@ def preprocess_inputs(
return inputs


class _OVMistral3ForCausalLM(OVModelForVisualCausalLM):
def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
if input_ids is not None and input_ids.shape[1] == 1:
return None
if pixel_values is not None and pixel_values.dtype != torch.float32:
pixel_values = pixel_values.to(torch.float32)
return self.vision_embeddings(pixel_values).last_hidden_state

def merge_vision_text_embeddings(
self, vision_embeds, inputs_embeds, input_ids=None, attention_mask=None, position_ids=None, **kwargs
):
image_features = torch.from_numpy(vision_embeds) if isinstance(vision_embeds, np.ndarray) else vision_embeds
inputs_embeds = torch.from_numpy(inputs_embeds) if isinstance(inputs_embeds, np.ndarray) else inputs_embeds

image_token_id = getattr(self.config, "image_token_index", getattr(self.config, "image_token_id", 10))
special_image_mask = (input_ids == image_token_id).unsqueeze(-1)
special_image_mask = special_image_mask.expand_as(inputs_embeds)

image_features = image_features.view(-1, image_features.shape[-1]).to(inputs_embeds.device, inputs_embeds.dtype)
Comment on lines +4820 to +4823 (Copilot AI, Apr 20, 2026):

merge_vision_text_embeddings relies on masked_scatter to fail if the number of image-token positions doesn't match the number of vision embeddings, which can produce a fairly opaque runtime error. Consider adding an explicit count check (like the _OVLlama4ForCausalLM implementation just above) and raising a clear ValueError when there is a mismatch; also ensure the boolean mask is on inputs_embeds.device before scattering.

Suggested change:
- special_image_mask = (input_ids == image_token_id).unsqueeze(-1)
- special_image_mask = special_image_mask.expand_as(inputs_embeds)
- image_features = image_features.view(-1, image_features.shape[-1]).to(inputs_embeds.device, inputs_embeds.dtype)
+ special_image_mask = (input_ids == image_token_id).unsqueeze(-1).to(inputs_embeds.device)
+ image_features = image_features.view(-1, image_features.shape[-1]).to(inputs_embeds.device, inputs_embeds.dtype)
+ num_image_tokens = special_image_mask[..., 0].sum().item()
+ num_image_features = image_features.shape[0]
+ if num_image_tokens != num_image_features:
+     raise ValueError(
+         f"Image features and image tokens do not match: tokens: {num_image_tokens}, features {num_image_features}"
+     )
+ special_image_mask = special_image_mask.expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

return inputs_embeds, attention_mask, position_ids

@staticmethod
def preprocess_inputs(
text: str,
image: Optional["Image"] = None,
processor: Optional[AutoImageProcessor] = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
config: Optional[PretrainedConfig] = None,
video: Optional["VideoInput"] = None,
audio: Optional[np.ndarray] = None,
):
if processor is None:
raise ValueError("Processor is required.")
if video is not None:
raise ValueError("Video input is not supported")
if audio is not None:
raise ValueError("Audio input is not supported")
conversation = [
{
"role": "user",
"content": [
{"type": "text", "text": text},
],
}
]
if image is not None:
conversation[0]["content"].insert(0, {"type": "image"})

text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
inputs = processor(images=image, text=text_prompt, return_tensors="pt")
return inputs


MODEL_TYPE_TO_CLS_MAPPING = {
"llava": _OVLlavaForCausalLM,
"llava_next": _OVLlavaNextForCausalLM,
@@ -4824,4 +4879,5 @@ def preprocess_inputs(
"llama4": _OVLlama4ForCausalLM,
"qwen3_vl": _OVQwen3VLForCausalLM,
"minicpmo": _OVMiniCPMOForCausalLM,
"mistral3": _OVMistral3ForCausalLM,
}
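
For context, once this mapping routes "mistral3" to _OVMistral3ForCausalLM, a checkpoint can be exercised end to end through the generic VLM entry point. A hedged sketch (the model id and image path are illustrative; Mistral-Small-3.1 is one public Mistral3 checkpoint):

    from PIL import Image
    from transformers import AutoProcessor

    from optimum.intel import OVModelForVisualCausalLM

    model_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"  # illustrative Mistral3 checkpoint
    model = OVModelForVisualCausalLM.from_pretrained(model_id, export=True)
    processor = AutoProcessor.from_pretrained(model_id)

    image = Image.open("cat.png")  # illustrative image path
    inputs = model.preprocess_inputs(text="What is in this image?", image=image, processor=processor)
    outputs = model.generate(**inputs, max_new_tokens=32)
    print(processor.batch_decode(outputs, skip_special_tokens=True)[0])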