Skip to content
51 changes: 51 additions & 0 deletions optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4490,3 +4490,54 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
)

return dummy_inputs


class HunyuanDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use MistralDummyPastKeyValuesGenerator instead and set normalized_config.head_dim?

def __init__(
    self,
    task: str,
    normalized_config: NormalizedTextConfig,
    batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
    sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
    random_batch_size_range: Optional[Tuple[int, int]] = None,
    random_sequence_length_range: Optional[Tuple[int, int]] = None,
    **kwargs,
):
    """Initialise the generator and record the KV-head geometry from the config.

    The six named arguments are handed to the parent constructor unchanged.
    Any additional ``**kwargs`` are accepted so callers may pass them, but
    they are not forwarded to the parent.
    """
    parent_arguments = {
        "task": task,
        "normalized_config": normalized_config,
        "batch_size": batch_size,
        "sequence_length": sequence_length,
        "random_batch_size_range": random_batch_size_range,
        "random_sequence_length_range": random_sequence_length_range,
    }
    super().__init__(**parent_arguments)

    # Assigned after the parent constructor has run, so these values replace
    # anything the parent may have stored under the same attribute names.
    self.num_key_value_heads = normalized_config.num_key_value_heads
    # The per-head width is read from ``attention_head_dim`` on the config.
    self.head_dim = normalized_config.attention_head_dim

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
    """Build one random (key, value) pair per layer.

    Every tensor is requested with the shape
    ``(batch_size, num_key_value_heads, sequence_length, head_dim)`` taken
    from the instance attributes. ``input_name`` and ``int_dtype`` are part
    of the signature but are not used here.

    Returns:
        A list of ``self.num_layers`` 2-tuples, each holding two independent
        results of ``self.random_float_tensor``.
    """
    kv_shape = (self.batch_size, self.num_key_value_heads, self.sequence_length, self.head_dim)

    def _sample():
        # Each call draws a fresh tensor; key and value are never shared.
        return self.random_float_tensor(kv_shape, framework=framework, dtype=float_dtype)

    past_key_values = []
    for _ in range(self.num_layers):
        key_states = _sample()
        value_states = _sample()
        past_key_values.append((key_states, value_states))
    return past_key_values
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you mind adding a test as well?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I will add it once a release version of transformers supports this model.


@register_in_tasks_manager(
    "hunyuan_v1_dense",
    "text-generation",
    "text-generation-with-past",
    library_name="transformers",
)
class HunyuanOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
    """Export configuration registered for the ``hunyuan_v1_dense`` model type.

    Registered for the ``text-generation`` and ``text-generation-with-past``
    tasks of the ``transformers`` library. Past key/value dummy inputs come
    from ``HunyuanDummyPastKeyValuesGenerator``.
    """

    MIN_TRANSFORMERS_VERSION = "4.54.0"

    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, HunyuanDummyPastKeyValuesGenerator)
    DUMMY_PKV_GENERATOR_CLASS = HunyuanDummyPastKeyValuesGenerator
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

    def patch_model_for_export(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        model_kwargs: Optional[Dict[str, Any]] = None,
    ) -> "ModelPatcher":
        """Return the patcher wrapped around ``model`` for export.

        Always builds an ``UpdateCausalMaskModelPatcher`` from this config,
        the model, and the optional ``model_kwargs``.
        """
        patcher = UpdateCausalMaskModelPatcher(self, model, model_kwargs=model_kwargs)
        return patcher
2 changes: 1 addition & 1 deletion optimum/exporters/openvino/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6753,4 +6753,4 @@ def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
setattr(self._model, self.orig_forward_name, self.orig_forward)
for layer in self._model.backbone.layers:
layer.mixer.forward = layer.mixer._orig_forward
layer.mixer.forward = layer.mixer._orig_forward