Skip to content

Commit eac3893

Browse files
committed
Applied comments.
1 parent 942112a commit eac3893

4 files changed

Lines changed: 8 additions & 9 deletions

File tree

optimum/exporters/openvino/model_configs.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@
1414

1515
import enum
1616
import logging
17+
import math
1718
from copy import deepcopy
1819
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
1920

21+
import torch
2022
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
2123

2224
from optimum.exporters.onnx.config import OnnxConfig, TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
@@ -4370,10 +4372,6 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
43704372
dtype=float_dtype,
43714373
)
43724374
if input_name == "image_position_ids":
4373-
import math
4374-
4375-
import torch
4376-
43774375
# Create position ids as a grid. The patch count = h_patches * w_patches
43784376
# where both are divisible by pooling_kernel_size for correct pooling.
43794377
k = self.pooling_kernel_size

optimum/intel/openvino/modeling_decoder.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,19 @@
3333
from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput
3434
from transformers.utils.hub import PushToHubMixin
3535

36+
from ..utils.import_utils import compare_versions, is_transformers_version
3637

37-
try:
38+
39+
if is_transformers_version("<", "5.5"):
3840
from transformers.models.mamba.modeling_mamba import MambaCache
39-
except ImportError:
41+
else:
4042
MambaCache = object
4143

4244
from optimum.utils.normalized_config import NormalizedConfigManager
4345

4446
from ...exporters.openvino import ensure_stateful_is_available, main_export, patch_stateful
4547
from ...exporters.openvino.stateful import model_has_state
4648
from ...exporters.openvino.utils import SSM_MODELS
47-
from ..utils.import_utils import compare_versions
4849
from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS
4950
from .configuration import (
5051
OVConfig,

tests/openvino/test_seq2seq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -787,7 +787,7 @@ def compare_outputs(inputs, ov_model, transformers_model, generation_config):
787787
with torch.no_grad():
788788
transformers_outputs = transformers_model(**transformers_inputs)
789789
self.assertTrue(
790-
torch.allclose(ov_outputs.logits, transformers_outputs.logits.to(torch.float32), atol=4e-3),
790+
torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=4e-3),
791791
f"Max abs diff {(torch.abs(ov_outputs.logits - transformers_outputs.logits).max())}",
792792
)
793793

tests/openvino/utils_tests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,7 @@
383383
"text_embeddings_per_layer_model": 1,
384384
},
385385
"gemma4_moe": {
386-
"lm_model": 44,
386+
"lm_model": 48,
387387
"text_embeddings_model": 1,
388388
"vision_embeddings_model": 10,
389389
"text_embeddings_per_layer_model": 0,

0 commit comments

Comments
 (0)