diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py
index 0624624a77..af0586cdce 100644
--- a/optimum/exporters/openvino/model_configs.py
+++ b/optimum/exporters/openvino/model_configs.py
@@ -176,6 +176,8 @@
     MiniCPMVImageEmbeddingsModelPatcher,
     MiniCPMVResamplerModelPatcher,
     MistralModelPatcher,
+    Mistral3ImageEmbeddingModelPatcher,
+    Mistral3LanguageModelPatcher,
     MixtralModelPatcher,
     MPTModelPatcher,
     OVDecoderModelPatcher,
@@ -287,6 +289,10 @@ def init_model_configs():
         "transformers",
         "AutoModelForImageTextToText",
     )
+    TasksManager._CUSTOM_CLASSES[("pt", "mistral3", "image-text-to-text")] = (
+        "transformers",
+        "Mistral3ForConditionalGeneration",
+    )
 
     if is_diffusers_available() and "fill" not in TasksManager._DIFFUSERS_TASKS_TO_MODEL_LOADERS:
         TasksManager._DIFFUSERS_TASKS_TO_MODEL_LOADERS["fill"] = "FluxFillPipeline"
@@ -4171,6 +4177,101 @@ def with_behavior(
         return super().with_behavior(behavior)
 
 
+@register_in_tasks_manager("mistral3", *["image-text-to-text"], library_name="transformers")
+class Mistral3OpenVINOConfig(BaseVLMOpenVINOConfig):
+    MIN_TRANSFORMERS_VERSION = "5.4.0"
+
+    def __init__(
+        self,
+        config: "PretrainedConfig",
+        task: str = "feature-extraction",
+        int_dtype: str = "int64",
+        float_dtype: str = "fp32",
+        behavior: VLMConfigBehavior = VLMConfigBehavior.VISION_EMBEDDINGS,
+        preprocessors: Optional[List[Any]] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            config=config,
+            task=task,
+            int_dtype=int_dtype,
+            float_dtype=float_dtype,
+            preprocessors=preprocessors,
+        )
+        self._behavior = behavior
+        self._orig_config = config
+        if self._behavior == VLMConfigBehavior.VISION_EMBEDDINGS and hasattr(config, "vision_config"):
+            self._config = config.vision_config
+            self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config)
+
+    def with_behavior(
+        self,
+        behavior: Union[str, VLMConfigBehavior],
+    ):
+        if isinstance(behavior, str) and not isinstance(behavior, VLMConfigBehavior):
+            behavior = VLMConfigBehavior(behavior)
+
+        if behavior == VLMConfigBehavior.TEXT_EMBEDDINGS:
+            model_type = self._orig_config.text_config.model_type
+            if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
+                model_type = "mistral"
+            return get_vlm_text_embeddings_config(
+                model_type, self._orig_config.text_config, self.int_dtype, self.float_dtype
+            )
+
+        if behavior == VLMConfigBehavior.LANGUAGE:
+            model_type = self._orig_config.text_config.model_type
+            if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
+                model_type = "mistral"
+            return get_vlm_text_generation_config(
+                model_type,
+                self._orig_config.text_config,
+                self.int_dtype,
+                self.float_dtype,
+                model_patcher=Mistral3LanguageModelPatcher,
+            )
+
+        if behavior == VLMConfigBehavior.VISION_EMBEDDINGS:
+            return self.__class__(
+                self._orig_config,
+                task=self.task,
+                int_dtype=self.int_dtype,
+                float_dtype=self.float_dtype,
+                behavior=behavior,
+                preprocessors=self._preprocessors,
+            )
+
+    def get_model_for_behavior(self, model, behavior: Union[str, VLMConfigBehavior]):
+        if isinstance(behavior, str) and not isinstance(behavior, VLMConfigBehavior):
+            behavior = VLMConfigBehavior(behavior)
+
+        if behavior == VLMConfigBehavior.LANGUAGE:
+            return model  # full model needed for tied lm_head
+
+        if behavior == VLMConfigBehavior.VISION_EMBEDDINGS:
+            return model  # full model, patcher replaces forward
+
+        if behavior == VLMConfigBehavior.TEXT_EMBEDDINGS:
+            text_embedding = model.get_input_embeddings()
+            text_embedding.config = model.model.language_model.config
+            return text_embedding
+
+    def patch_model_for_export(self, model, model_kwargs=None):
+        model_kwargs = model_kwargs or {}
+        if self._behavior == VLMConfigBehavior.VISION_EMBEDDINGS:
+            return Mistral3ImageEmbeddingModelPatcher(self, model, model_kwargs)
+        return super().patch_model_for_export(model, model_kwargs)
+
+    @property
+    def outputs(self):
+        if self._behavior == VLMConfigBehavior.VISION_EMBEDDINGS:
+            return {"last_hidden_state": {0: "num_patches"}}
+        return super().outputs
+
+    def generate_dummy_inputs(self, framework="pt", **kwargs):
+        if self._behavior == VLMConfigBehavior.VISION_EMBEDDINGS:
+            kwargs["batch_size"] = 1
+        return super().generate_dummy_inputs(framework, **kwargs)
+
+
 class DummyVisionPositionIdsInputGenerator(DummyVisionInputGenerator):
     SUPPORTED_INPUT_NAMES = ("patch_attention_mask", "patch_position_ids")
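
Note: a minimal sketch of how these behaviors are expected to be driven, mirroring the other VLM configs in this file (the behavior strings assume the VLMConfigBehavior enum values used elsewhere; not part of the patch):

    config = Mistral3OpenVINOConfig(model.config, task="image-text-to-text")
    vision_cfg = config.with_behavior("vision_embeddings")  # inlined vision tower + projector
    embed_cfg = config.with_behavior("text_embeddings")     # wraps model.get_input_embeddings()
    lang_cfg = config.with_behavior("language")             # mistral text config + Mistral3LanguageModelPatcher
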
diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
index 32dd2d6c6d..26f6254ca6 100644
--- a/optimum/exporters/openvino/model_patcher.py
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -227,6 +227,15 @@ def patch_cos_sin_cached_fp32(model):
 def eager_mask_without_vmap(*args, **kwargs) -> Optional[torch.Tensor]:
     kwargs.pop("allow_is_causal_skip", None)
     dtype = kwargs.get("dtype", torch.float32)
+    # Handle transformers >= 5.4 API change: q_length kwarg instead of cache_position positional arg
+    if "q_length" in kwargs and "cache_position" not in kwargs:
+        q_length = kwargs.pop("q_length")
+        q_offset = kwargs.pop("q_offset", 0)
+        device = kwargs.get("device", "cpu")
+        kwargs["cache_position"] = torch.arange(q_offset, q_offset + q_length, device=device)
+    # Remove kwargs not accepted by sdpa_mask_without_vmap
+    for key in ["allow_is_bidirectional_skip", "allow_torch_fix", "use_vmap", "config", "dtype"]:
+        kwargs.pop(key, None)
     mask = sdpa_mask_without_vmap(*args, allow_is_causal_skip=False, **kwargs)
     # we use torch.finfo(torch.float16).min instead torch.finfo(dtype).min to avoid an overflow but not
     # sure this is the right way to handle this, we are basically pretending that -65,504 is -inf
@@ -8319,3 +8328,160 @@ def __exit__(self, exc_type, exc_value, traceback):
             sparse_moe_block = decoder_layer.mlp
             decoder_layer.mlp.forward = decoder_layer.mlp._orig_forward
             del sparse_moe_block.down_projs, sparse_moe_block.gate_projs, sparse_moe_block.up_projs
+
+
+def _mistral3_vision_embed_forward(self, pixel_values):
+    """
+    Inline vision pipeline (vision_tower + multi_modal_projector) for Mistral3.
+    All dimensions are derived from tensor .shape to stay dynamic during OpenVINO tracing.
+    """
+    vision_tower = self.model.vision_tower
+    projector = self.model.multi_modal_projector
+    config = self.config
+
+    # Step 1: Patch convolution
+    target_dtype = vision_tower.patch_conv.weight.dtype
+    patch_embeds = vision_tower.patch_conv(pixel_values.to(dtype=target_dtype))
+    # patch_embeds: (batch, hidden, h_patches, w_patches)
+    h_patches = patch_embeds.shape[2]
+    w_patches = patch_embeds.shape[3]
+    d = patch_embeds.shape[1]
+
+    # Step 2: Flatten and normalize (single image, batch=1)
+    patch_embeds = patch_embeds[0].flatten(1).T.unsqueeze(0)  # (1, h*w, d)
+    patch_embeds = vision_tower.ln_pre(patch_embeds)
+
+    # Step 3: Position embeddings - derive from tensor shapes
+    max_width = config.vision_config.image_size // config.vision_config.patch_size
+    h_idx = torch.arange(h_patches, device=pixel_values.device)
+    w_idx = torch.arange(w_patches, device=pixel_values.device)
+    mesh_h, mesh_w = torch.meshgrid(h_idx, w_idx, indexing="ij")
+    position_ids = mesh_h.reshape(-1) * max_width + mesh_w.reshape(-1)
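+    # Worked example: a 2x3 patch grid with max_width = 4 gives
+    # mesh_h = [0, 0, 0, 1, 1, 1] and mesh_w = [0, 1, 2, 0, 1, 2], so
+    # position_ids = [0, 1, 2, 4, 5, 6] (row-major indices into the max_width-wide position grid)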
+ """ + vision_tower = self.model.vision_tower + projector = self.model.multi_modal_projector + config = self.config + + # Step 1: Patch convolution + target_dtype = vision_tower.patch_conv.weight.dtype + patch_embeds = vision_tower.patch_conv(pixel_values.to(dtype=target_dtype)) + # patch_embeds: (batch, hidden, h_patches, w_patches) + h_patches = patch_embeds.shape[2] + w_patches = patch_embeds.shape[3] + d = patch_embeds.shape[1] + + # Step 2: Flatten and normalize (single image, batch=1) + patch_embeds = patch_embeds[0].flatten(1).T.unsqueeze(0) # (1, h*w, d) + patch_embeds = vision_tower.ln_pre(patch_embeds) + + # Step 3: Position embeddings - derive from tensor shapes + max_width = config.vision_config.image_size // config.vision_config.patch_size + h_idx = torch.arange(h_patches, device=pixel_values.device) + w_idx = torch.arange(w_patches, device=pixel_values.device) + mesh_h, mesh_w = torch.meshgrid(h_idx, w_idx, indexing="ij") + position_ids = (mesh_h.reshape(-1) * max_width + mesh_w.reshape(-1)) + + position_embeddings = vision_tower.patch_positional_embedding(patch_embeds, position_ids) + + # Step 4: Build block attention mask for non-flash attention + seq_len = patch_embeds.shape[1] + causal_mask = torch.zeros((seq_len, seq_len), dtype=patch_embeds.dtype, device=patch_embeds.device) + attention_mask = causal_mask[None, None, :, :].expand(1, 1, -1, -1) + + # Step 5: Transformer layers + transformer_out = vision_tower.transformer( + patch_embeds, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + ) + + # Step 6: Select feature layer + selected = transformer_out.last_hidden_state # (1, num_patches, hidden) + + # Step 7: Apply projector (norm + patch_merger + MLP) + image_features = projector.norm(selected.squeeze(0)) # (num_patches, hidden) + + # Patch merger - unfold + merge + spatial_merge = config.spatial_merge_size + patch_size = config.vision_config.patch_size + image_grid = image_features.view(h_patches, w_patches, d).permute(2, 0, 1).unsqueeze(0) + grid = torch.nn.functional.unfold(image_grid, kernel_size=spatial_merge, stride=spatial_merge) + grid = grid.view(d * spatial_merge ** 2, -1).t() + image_features = projector.patch_merger.merging_layer(grid) + + # MLP projection + image_features = projector.linear_1(image_features) + image_features = projector.act(image_features) + image_features = projector.linear_2(image_features) + + return image_features + + +class Mistral3ImageEmbeddingModelPatcher(ModelPatcher): + def __init__(self, config, model, model_kwargs): + model.__orig_forward = model.forward + model.forward = types.MethodType(_mistral3_vision_embed_forward, model) + super().__init__(config, model, model_kwargs) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + self._model.forward = self._model.__orig_forward + + +class Mistral3LanguageModelPatcher(OVDecoderModelPatcher): + """Patcher for Mistral3 language model: fixes sliding_window=None and injects cache_position.""" + + def __init__( + self, + config: "OnnxConfig", + model: "PreTrainedModel", + model_kwargs: Optional[Dict[str, Any]] = None, + ): + # Override forward to inject cache_position (required by attention mask functions) + def forward(self, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache=True): + # Compute cache_position from past_key_values shape + if isinstance(past_key_values, (tuple, list)) and len(past_key_values) > 0: + if isinstance(past_key_values[0], (tuple, list)): + past_seen_tokens = 
+
+    # MLP projection
+    image_features = projector.linear_1(image_features)
+    image_features = projector.act(image_features)
+    image_features = projector.linear_2(image_features)
+
+    return image_features
+
+
+class Mistral3ImageEmbeddingModelPatcher(ModelPatcher):
+    def __init__(self, config, model, model_kwargs):
+        model.__orig_forward = model.forward
+        model.forward = types.MethodType(_mistral3_vision_embed_forward, model)
+        super().__init__(config, model, model_kwargs)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        self._model.forward = self._model.__orig_forward
+
+
+class Mistral3LanguageModelPatcher(OVDecoderModelPatcher):
+    """Patcher for Mistral3 language model: fixes sliding_window=None and injects cache_position."""
+
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: "PreTrainedModel",
+        model_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        # Override forward to inject cache_position (required by attention mask functions)
+        def forward(self, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache=True):
+            # Compute cache_position from past_key_values shape
+            if isinstance(past_key_values, (tuple, list)) and len(past_key_values) > 0:
+                if isinstance(past_key_values[0], (tuple, list)):
+                    past_seen_tokens = past_key_values[0][0].shape[-2]
+                else:
+                    past_seen_tokens = past_key_values[0].shape[-2]
+            else:
+                past_seen_tokens = 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
+
+            return self.__orig_forward(
+                input_ids=None,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                cache_position=cache_position,
+                past_key_values=past_key_values,
+                inputs_embeds=inputs_embeds,
+                use_cache=use_cache,
+            )
+
+        model.__orig_forward = model.forward
+        model.forward = types.MethodType(forward, model)
+
+        super().__init__(config, model, model_kwargs)
+
+    def _get_language_model(self):
+        """Find the language model within the Mistral3 composite model."""
+        if hasattr(self._model, "model") and hasattr(self._model.model, "language_model"):
+            return self._model.model.language_model
+        if hasattr(self._model, "language_model"):
+            return self._model.language_model
+        return self._model
+
+    def __enter__(self):
+        # Fix sliding_window=None in the language model config
+        lang_model = self._get_language_model()
+        cfg = lang_model.config if hasattr(lang_model, "config") else self._model.config
+        self._orig_sliding_window = getattr(cfg, "sliding_window", None)
+        if self._orig_sliding_window is None:
+            cfg.sliding_window = getattr(cfg, "max_position_embeddings", 262144)
+
+        if hasattr(lang_model, "layers"):
+            for layer in lang_model.layers:
+                if hasattr(layer.self_attn, "sliding_window"):
+                    layer.self_attn._orig_sliding_window = layer.self_attn.sliding_window
+                    if layer.self_attn.sliding_window is None:
+                        layer.self_attn.sliding_window = cfg.sliding_window
+
+        super().__enter__()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+
+        # Restore forward
+        self._model.forward = self._model.__orig_forward
+
+        # Restore sliding_window
+        lang_model = self._get_language_model()
+        cfg = lang_model.config if hasattr(lang_model, "config") else self._model.config
+        cfg.sliding_window = self._orig_sliding_window
+        if hasattr(lang_model, "layers"):
+            for layer in lang_model.layers:
+                if hasattr(layer.self_attn, "_orig_sliding_window"):
+                    layer.self_attn.sliding_window = layer.self_attn._orig_sliding_window
+                    del layer.self_attn._orig_sliding_window
diff --git a/optimum/exporters/openvino/utils.py b/optimum/exporters/openvino/utils.py
index af2f1edaba..38f215621f 100644
--- a/optimum/exporters/openvino/utils.py
+++ b/optimum/exporters/openvino/utils.py
@@ -303,6 +303,7 @@ def get_submodels(model):
     "phi4_multimodal",
     "llama4",
     "minicpmo",
+    "mistral3",
 ]
 
 SSM_MODELS = ["mamba", "falcon_mamba", "zamba2", "lfm2", "granitemoehybrid", "qwen3_next"]
diff --git a/optimum/intel/openvino/modeling_visual_language.py b/optimum/intel/openvino/modeling_visual_language.py
index beb7b974eb..373c8b68fe 100644
--- a/optimum/intel/openvino/modeling_visual_language.py
+++ b/optimum/intel/openvino/modeling_visual_language.py
@@ -4802,6 +4802,61 @@ def preprocess_inputs(
         return inputs
 
 
+class _OVMistral3ForCausalLM(OVModelForVisualCausalLM):
+    def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
+        if input_ids is not None and input_ids.shape[1] == 1:
+            return None
+        if pixel_values is not None and pixel_values.dtype != torch.float32:
+            pixel_values = pixel_values.to(torch.float32)
+        return self.vision_embeddings(pixel_values).last_hidden_state
+
+    def merge_vision_text_embeddings(
+        self, vision_embeds, inputs_embeds, input_ids=None, attention_mask=None, position_ids=None, **kwargs
+    ):
+        image_features = torch.from_numpy(vision_embeds) if isinstance(vision_embeds, np.ndarray) else vision_embeds
+        inputs_embeds = torch.from_numpy(inputs_embeds) if isinstance(inputs_embeds, np.ndarray) else inputs_embeds
+
+        image_token_id = getattr(self.config, "image_token_index", getattr(self.config, "image_token_id", 10))
+        special_image_mask = (input_ids == image_token_id).unsqueeze(-1)
+        special_image_mask = special_image_mask.expand_as(inputs_embeds)
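+        # Example: with three image placeholders in input_ids, special_image_mask marks those three
+        # embedding rows; masked_scatter below overwrites exactly those rows with the three image
+        # feature rows, in order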
+
+        image_features = image_features.view(-1, image_features.shape[-1]).to(inputs_embeds.device, inputs_embeds.dtype)
+        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+        return inputs_embeds, attention_mask, position_ids
+
+    @staticmethod
+    def preprocess_inputs(
+        text: str,
+        image: Optional["Image"] = None,
+        processor: Optional[AutoImageProcessor] = None,
+        tokenizer: Optional[PreTrainedTokenizer] = None,
+        config: Optional[PretrainedConfig] = None,
+        video: Optional["VideoInput"] = None,
+        audio: Optional[np.ndarray] = None,
+    ):
+        if processor is None:
+            raise ValueError("Processor is required.")
+        if video is not None:
+            raise ValueError("Video input is not supported")
+        if audio is not None:
+            raise ValueError("Audio input is not supported")
+        conversation = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": text},
+                ],
+            }
+        ]
+        if image is not None:
+            conversation[0]["content"].insert(0, {"type": "image"})
+
+        text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
+        inputs = processor(images=image, text=text_prompt, return_tensors="pt")
+        return inputs
+
+
 MODEL_TYPE_TO_CLS_MAPPING = {
     "llava": _OVLlavaForCausalLM,
     "llava_next": _OVLlavaNextForCausalLM,
@@ -4824,4 +4879,5 @@ def preprocess_inputs(
     "llama4": _OVLlama4ForCausalLM,
     "qwen3_vl": _OVQwen3VLForCausalLM,
     "minicpmo": _OVMiniCPMOForCausalLM,
+    "mistral3": _OVMistral3ForCausalLM,
 }
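
Usage sketch (illustrative only, not part of the patch; the checkpoint ID and image path are placeholders):

    from PIL import Image
    from transformers import AutoProcessor
    from optimum.intel import OVModelForVisualCausalLM

    model_id = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"  # placeholder checkpoint
    model = OVModelForVisualCausalLM.from_pretrained(model_id, export=True)
    processor = AutoProcessor.from_pretrained(model_id)
    image = Image.open("example.jpg")  # placeholder image
    inputs = model.preprocess_inputs(text="Describe this image", image=image, processor=processor)
    generated = model.generate(**inputs, max_new_tokens=64)
    print(processor.batch_decode(generated, skip_special_tokens=True)[0])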