8 changes: 8 additions & 0 deletions optimum/commands/export/openvino.py
@@ -274,6 +274,13 @@ def parse_args_openvino(parser: "ArgumentParser"):
"reduces quantization error. Valid only when activations quantization is enabled."
),
)
optional_group.add_argument(
    "--eagle3",
    action="store_true",
    help=(
        "Indicates that the model being exported is a draft model of the EAGLE3 speculative decoding pipeline."
    ),
)
optional_group.add_argument(
"--model-kwargs",
type=json.loads,
@@ -576,6 +583,7 @@ def run(self):
library_name=library_name,
variant=self.args.variant,
model_kwargs=self.args.model_kwargs,
eagle3=self.args.eagle3,
# **input_shapes,
)

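With the new flag, exporting a draft model from the command line could look like the following sketch (the local model path and output directory are hypothetical; the export patches config.json on disk, so a locally downloaded copy of the draft model is assumed):

optimum-cli export openvino -m ./EAGLE3-LLaMA3.1-Instruct-8B --task text-generation-with-past --eagle3 ./ov_eagle3_draft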
56 changes: 56 additions & 0 deletions optimum/exporters/openvino/__main__.py
@@ -16,12 +16,16 @@
import logging
import operator
import warnings
import json
import os
import importlib.util
from functools import reduce
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union

from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from requests.exceptions import ConnectionError as RequestsConnectionError
from safetensors.torch import save_file
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase, ProcessorMixin
from transformers.utils import is_torch_available

@@ -105,6 +109,50 @@ def infer_task(
)
return task

def eagle3_config(model_path: str):
    """Temporarily rewrite config.json so that the Auto* classes resolve to the EAGLE3 classes in model_patcher."""
    config_file = os.path.join(model_path, 'config.json')
    # keep the original config under a new name
    org_config_file = os.path.join(model_path, 'config_org.json')
    os.rename(config_file, org_config_file)

    # read the original config
    with open(org_config_file, 'r', encoding='utf-8') as f:
        config = json.load(f)
    # append the eagle3 suffix to the model type
    if 'model_type' in config:
        org_type = config['model_type']
        if 'eagle3' not in org_type:
            config['model_type'] = org_type + 'eagle3'
    # point auto_map at the EAGLE3 classes shipped with the OpenVINO exporter
    module_name = 'optimum.exporters.openvino.model_patcher'
    spec = importlib.util.find_spec(module_name)
    if spec and spec.origin:
        module_path = os.path.dirname(spec.origin)
        config['auto_map'] = {
            "AutoConfig": module_path + "--model_patcher.LlamaEagle3Config",
            "AutoModel": module_path + "--model_patcher.LlamaEagle3Model",
            "AutoModelForCausalLM": module_path + "--model_patcher.LlamaEagle3ForCausalLM",
        }
    # write the new config.json
    with open(config_file, 'w', encoding='utf-8') as f:
        json.dump(config, f, ensure_ascii=False, indent=2)

def extract_d2t(model_path: str, output_path: str):
    """Extract the d2t/t2d mapping tensors from pytorch_model.bin and save them as eagle3.safetensors."""
    load_model_path = os.path.join(model_path, "pytorch_model.bin")
    output_path = os.path.join(output_path, "eagle3.safetensors")
    target_keys = ['d2t', 't2d']
    if os.path.exists(load_model_path):
        state_dict = torch.load(load_model_path, map_location=torch.device('cpu'))
        extracted = {k: state_dict[k] for k in target_keys if k in state_dict}
        # save the extracted tensors
        save_file(extracted, output_path)

def restore_config(model_path: str, ov_path: str):
    """Restore the original config.json and, if the output directory exists, extract the d2t/t2d tensors into it."""
    config_file = os.path.join(model_path, 'config.json')
    org_config_file = os.path.join(model_path, 'config_org.json')
    os.rename(org_config_file, config_file)
    if os.path.exists(ov_path):
        extract_d2t(model_path, ov_path)
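To make the effect of eagle3_config concrete: for a llama-type draft model it temporarily turns config.json into something like the excerpt below, where the install-path prefix is hypothetical (it is resolved at runtime from the location of model_patcher.py); restore_config later puts the original file back.

# Sketch of the rewritten config.json contents (excerpt); the path prefix is an assumed install location.
rewritten_config_excerpt = {
    "model_type": "llamaeagle3",
    "auto_map": {
        "AutoConfig": "/site-packages/optimum/exporters/openvino--model_patcher.LlamaEagle3Config",
        "AutoModel": "/site-packages/optimum/exporters/openvino--model_patcher.LlamaEagle3Model",
        "AutoModelForCausalLM": "/site-packages/optimum/exporters/openvino--model_patcher.LlamaEagle3ForCausalLM",
    },
}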

def main_export(
model_name_or_path: str,
@@ -130,6 +178,7 @@ def main_export(
library_name: Optional[str] = None,
model_loading_kwargs: Optional[Dict[str, Any]] = None,
variant: Optional[str] = None,
eagle3: bool = False,
**kwargs_shapes,
):
"""
@@ -187,6 +236,8 @@
especially useful when exporting a custom architecture that needs to split the ONNX (e.g. encoder-decoder). If unspecified with custom models, optimum will try to use the default submodels used for the given task, with no guarantee of success.
stateful (`bool`, defaults to `True`):
Produce stateful model where all kv-cache inputs and outputs are hidden in the model and are not exposed as model inputs and outputs. Applicable only for decoder models.
eagle3 (`bool`, defaults to `False`):
    Whether the model being exported is an EAGLE3 draft model. Enables the EAGLE3-specific export behavior (config patching, the extra hidden_states input, and extraction of the d2t/t2d tensors).
**kwargs_shapes (`Dict`):
Shapes to use during inference. This argument allows to override the default shapes used during the ONNX export.

@@ -251,6 +302,9 @@
dtype = getattr(torch, dtype) if dtype != "auto" else dtype

    if library_name == "transformers":
        if eagle3:
            # patch the draft model's config before it is loaded below
            eagle3_config(model_name_or_path)

        config = AutoConfig.from_pretrained(
            model_name_or_path,
            subfolder=subfolder,
@@ -539,6 +593,8 @@ class StoreAttr(object):
torch.cuda.is_available = orig_cuda_check
if do_gptq_patching:
GPTQQuantizer.post_init_model = orig_post_init_model
if eagle3 and library_name == "transformers":
restore_config(model_name_or_path, output)


def maybe_convert_tokenizers(library_name: str, output: Path, model=None, preprocessors=None, task=None):
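Taken together, the changes in this file allow a Python-level export along the lines of the sketch below (paths are hypothetical, and a local draft-model directory containing config.json and pytorch_model.bin is assumed); afterwards the original config.json is restored and the d2t/t2d tensors are written next to the OpenVINO IR:

from safetensors.torch import load_file

from optimum.exporters.openvino import main_export

draft_dir = "./EAGLE3-LLaMA3.1-Instruct-8B"   # local draft-model directory (hypothetical path)
output_dir = "./ov_eagle3_draft"

# eagle3=True enables the config rewrite, the llamaeagle3 export path and the d2t/t2d extraction
main_export(draft_dir, output=output_dir, task="text-generation-with-past", eagle3=True)

# extract_d2t() writes the mapping tensors next to the exported model
tensors = load_file(f"{output_dir}/eagle3.safetensors")
print(sorted(tensors))  # ['d2t', 't2d'] when both keys are present in pytorch_model.bin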
13 changes: 13 additions & 0 deletions optimum/exporters/openvino/convert.py
@@ -131,6 +131,9 @@ def _save_model(

    runtime_options = config.runtime_options if hasattr(config, "runtime_options") else {}
    model = _add_runtime_options_to_rt_info(model, runtime_options)

    if 'eagle3' in config._config.model_type:
        model = _add_eagle3_mode_to_rt_info(model)
    save_model(model, path, compress_to_fp16)
    del model
    gc.collect()
@@ -831,6 +834,16 @@ def _add_runtime_options_to_rt_info(model: Model, options: Dict):

return model

def _add_eagle3_mode_to_rt_info(model: Model):
    """
    Add an eagle3_mode flag to the model's runtime info.
    """
    try:
        model.set_rt_info("True", ["eagle3_mode"])
    except Exception:
        pass

    return model

def _add_version_info_to_model(model: Model, library_name: Optional[str] = None):
"""
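As a quick sanity check, the flag written by _add_eagle3_mode_to_rt_info can be read back from the saved IR; a sketch, reusing the hypothetical output directory from the earlier example:

import openvino as ov

core = ov.Core()
ov_model = core.read_model("./ov_eagle3_draft/openvino_model.xml")

# rt_info entry set via set_rt_info("True", ["eagle3_mode"]) during export
if ov_model.has_rt_info(["eagle3_mode"]):
    print(ov_model.get_rt_info(["eagle3_mode"]))  # "True"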
46 changes: 46 additions & 0 deletions optimum/exporters/openvino/model_configs.py
@@ -4278,3 +4278,49 @@ class GPT2OpenVINOConfig(GPT2OnnxConfig):
)
class VisionEncoderDecoderOpenVINOConfig(VisionEncoderDecoderOnnxConfig):
_MODEL_PATCHER = OVSeq2SeqModelPatcher


class EAGLE3DummyGenerator(DummyInputGenerator):
    """
    Generates dummy hidden_states inputs whose last dimension is 3 * hidden_size.
    """

    SUPPORTED_INPUT_NAMES = ("hidden_states",)

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedTextConfig,
        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
        **kwargs,
    ):
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.hidden_size = normalized_config.hidden_size

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        shape = (
            self.batch_size,
            self.sequence_length,
            self.hidden_size * 3,
        )
        return self.random_float_tensor(shape, framework=framework, dtype=float_dtype)

@register_in_tasks_manager("llamaeagle3", *["text-generation", "text-generation-with-past"], library_name="transformers")
Collaborator:
What kind of model can we convert with this addition? I am asking because the original model has a different model type, llama3.
Can you only convert a local copy with a modified model type? I am not sure it is capable of converting the original eagle3 llama model.
Also, the implemented solution does not look scalable to other eagle3 models such as https://huggingface.co/nvidia/gpt-oss-120b-Eagle3

Contributor:
We have verified the conversion and the GenAI pipeline locally with yuhuili/EAGLE3-LLaMA3.1-Instruct-8B and Tengyunw/qwen3_8b_eagle3, and AngelSlim/Qwen3-1.7B_eagle3 will be added to the GenAI repo tests in openvinotoolkit/openvino.genai#2740. We checked the list on the EAGLE3 GitHub repo: most of the models are llama type, so they can be converted in theory or with limited updates. Can we merge this PR first and leave further verification to follow OpenVINO base-model support progress and customer requirements?

AngelSlim/Qwen3-14B_eagle3/config.json:  "model_type": "qwen3",
AngelSlim/Qwen3-a3B_eagle3/config.json:  "model_type": "llama",
AngelSlim/Qwen3-32B_eagle3/config.json:  "model_type": "llama",
AngelSlim/Qwen3-4B_eagle3/config.json:  "model_type": "llama",
AngelSlim/Qwen3-8B_eagle3/config.json:  "model_type": "llama",
AngelSlim/Qwen3-1.7B_eagle3/config.json:  "model_type": "llama",
linglingdan/Eagle3_for_MiniCPM4/config.json:  "model_type": "llama", 
lmsys/EAGLE3-gpt-oss-120b-bf16/config.json:  "model_type": "llama",
lmsys/sglang-EAGLE3-Llama-4-Scout-17B-16E-Instruct-v1/config.json:  "model_type": "llama",
lmsys/Qwen3-235B-A22B-EAGLE3/config.json:  "model_type": "llama",
lmsys/sglang-EAGLE3-Llama-4-Maverick-17B-128E-Instruct-v1/config.json:  "model_type": "llama",
nvidia/gpt-oss-120b-Eagle3/config.json:  "model_type": "llama",
nvidia/Qwen3-235B-A22B-Eagle3/config.json:  "model_type": "llama",
nvidia/Llama-4-Maverick-17B-128E-Eagle3 ??,
Tengyunw/qwen3_30b_moe_eagle3/config.json:  "model_type": "llama",
Tengyunw/qwen3_8b_eagle3/config.json:  "model_type": "llama",
wantsleep/OLMoE_1B_7B_Eagle3/config.json:  "model_type": "olmoe",
yuhuili/EAGLE3-LLaMA3.3-Instruct-70B/config.json:  "model_type": "llama",
yuhuili/EAGLE3-DeepSeek-R1-Distill-LLaMA-8B/config.json:  "model_type": "llama",
yuhuili/EAGLE3-LLaMA3.1-Instruct-8B/config.json:  "model_type": "llama",
yuhuili/EAGLE3-Vicuna1.3-13B/config.json:  "model_type": "llama",
Zjcxy-SmartAI/Eagle3-Qwen3-4B-Instruct-2507-zh/config.json:  "model_type": "llama",

@rkazants (Collaborator), Oct 14, 2025:
Why don't we use the original model type? For this, it relies on a different model type that seems to have been modified manually by you; that is not how it should work. These changes should allow converting the original model. Where does the llamaeagle3 model type come from?
Does it mean that the user should re-create every eagle3 model and modify its model type, etc.?

Contributor:
@rkazants Discussed with Fang; work is in progress to avoid the config.json modification by passing model_type="llamaeagle3" to AutoConfig.from_pretrained.

@peterchen-intel (Contributor), Oct 15, 2025:
Regarding "why don't we use original model type?": the llama modeling in transformers cannot support the eagle3 draft model; the modeling for the eagle3 draft model comes from https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py. The current PR supports the conversion of eagle3 draft models with model_type: "llama" in config.json.

class LlamaEagle3OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
    DEFAULT_ONNX_OPSET = 14  # Llama now uses F.scaled_dot_product_attention by default for torch>=2.1.1.
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator, EAGLE3DummyGenerator)
    DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        common_inputs = super().inputs
        common_inputs["hidden_states"] = {0: "batch_size", 1: "sequence_length", 2: "hidden_size"}
        return common_inputs

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "ModelPatcher":
        return OVDecoderModelPatcher(self, model, model_kwargs=model_kwargs)
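For reference, the hidden_states input produced by EAGLE3DummyGenerator (and therefore expected by the exported draft model) has a last dimension of three times the hidden size; a minimal shape sketch with assumed values:

import torch

batch_size, sequence_length, hidden_size = 1, 8, 4096  # assumed example values

# matches the shape built in EAGLE3DummyGenerator.generate()
hidden_states = torch.randn(batch_size, sequence_length, 3 * hidden_size)
print(tuple(hidden_states.shape))  # (1, 8, 12288)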