Skip to content

Commit 910cbaf

Browse files
authored
Implicitly add olive version to saved onnx model proto (microsoft#2183)
## Implicitly add olive version to saved onnx model proto

## Checklist before requesting a review
- [ ] Add unit tests for this change.
- [x] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [x] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this change to be included in the release notes.

## (Optional) Issue link
1 parent c620dfd commit 910cbaf

File tree

5 files changed

+56
-5
lines changed

5 files changed

+56
-5
lines changed

olive/common/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,10 @@ def hardlink_copy_dir(src_dir, dst_dir, **kwargs):
497497
copy_dir(src_dir, dst_dir, copy_function=hardlink_copy_file, dirs_exist_ok=True, **kwargs)
498498

499499

500+
def is_hardlink(path: Union[str, Path]) -> bool:
    """Report whether the file at ``path`` has more than one hard link.

    Args:
        path: Filesystem location to inspect, as a string or ``Path``.

    Returns:
        ``True`` when the link count reported by ``stat`` is greater than one,
        ``False`` otherwise.

    Any error raised by ``Path.stat`` (for example, for a missing path) propagates
    to the caller.
    """
    link_count = Path(path).stat().st_nlink
    return link_count > 1
502+
503+
500504
def set_tempdir(tempdir: str = None):
501505
"""Set the root directory for tempfiles.
502506

olive/engine/packaging/packaging_generator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
)
1919
from olive.hardware.accelerator import AcceleratorSpec
2020
from olive.model import ONNXModelHandler
21+
from olive.passes.onnx.common import add_version_metadata_to_model_proto
2122
from olive.resource_path import ResourceType, create_resource_path
2223

2324
logger = logging.getLogger(__name__)
@@ -265,6 +266,8 @@ def _generate_onnx_mlflow_model(model_dir: Path, inference_config: dict):
265266
# MLFlow will save models with default config save_as_external_data=True
266267
# https://github.com/mlflow/mlflow/blob/1d6eaaa65dca18688d9d1efa3b8b96e25801b4e9/mlflow/onnx.py#L175
267268
# There will be an alphanumeric file generated in the same folder as the model file
269+
# Add olive version to metadata
270+
add_version_metadata_to_model_proto(model_proto)
268271
mlflow.onnx.save_model(
269272
model_proto,
270273
mlflow_model_path,

olive/passes/onnx/common.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,27 @@ def get_external_data_config() -> dict[str, PassConfigParam]:
7676
}
7777

7878

79+
def add_version_metadata_to_model_proto(model: onnx.ModelProto) -> onnx.ModelProto:
    """Record the Olive version in the model's ``metadata_props``.

    The version is stored under the ``olive_version`` key. If an entry with that
    key already exists, the first such entry is overwritten; otherwise a new
    entry is appended. The proto is modified in place.

    Args:
        model: Model proto whose metadata should carry the Olive version.

    Returns:
        The same ``model`` object that was passed in.
    """
    # Best-effort lookup: fall back to "unknown" when the package cannot be
    # imported or does not expose ``__version__``.
    try:
        import olive

        version = getattr(olive, "__version__", "unknown")
    except Exception:
        version = "unknown"

    # Reuse the first existing entry for the key, or create one lazily.
    existing = next((prop for prop in model.metadata_props if prop.key == "olive_version"), None)
    entry = existing if existing is not None else model.metadata_props.add()
    entry.key = "olive_version"
    entry.value = version

    return model
98+
99+
79100
def model_proto_to_file(
80101
model: onnx.ModelProto,
81102
output_path: Union[str, Path],
@@ -119,6 +140,8 @@ def model_proto_to_file(
119140
)
120141

121142
if not save_as_external_data:
143+
# Add olive version to metadata
144+
add_version_metadata_to_model_proto(model)
122145
# save model
123146
onnx.save_model(model, str(output_path))
124147
return False
@@ -136,6 +159,8 @@ def model_proto_to_file(
136159
if any(output_dir.iterdir()):
137160
raise RuntimeError(f"Output directory ({output_dir}) for external data is not empty.")
138161

162+
# Add olive version to metadata
163+
add_version_metadata_to_model_proto(model)
139164
# save model
140165
onnx.save_model(
141166
model,
@@ -392,7 +417,7 @@ def resave_model(
392417
if not external_file_names:
393418
if force_external_data:
394419
# save the model with single external data file
395-
model_proto_to_file(onnx.load(model_path), new_model_path, {"save_as_external_data": True})
420+
model_proto_to_file(onnx.load(model_path), new_model_path, save_as_external_data=True)
396421
return True
397422

398423
# no external data, so we can just copy the model
@@ -401,7 +426,7 @@ def resave_model(
401426

402427
if len(external_file_names) > 1:
403428
# save the model with single external data file
404-
model_proto_to_file(onnx.load(model_path), new_model_path, {"save_as_external_data": True})
429+
model_proto_to_file(onnx.load(model_path), new_model_path, save_as_external_data=True)
405430
return True
406431

407432
external_file_path = str(model_path.parent / external_file_names[0])

olive/passes/onnx/static_llm.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,12 @@
1010
from olive.hardware.accelerator import AcceleratorSpec
1111
from olive.model import CompositeModelHandler, ONNXModelHandler
1212
from olive.passes import Pass
13-
from olive.passes.onnx.common import fix_dim_params, process_llm_pipeline, resave_model
13+
from olive.passes.onnx.common import (
14+
add_version_metadata_to_model_proto,
15+
fix_dim_params,
16+
process_llm_pipeline,
17+
resave_model,
18+
)
1419
from olive.passes.onnx.onnx_dag import OnnxDAG
1520
from olive.passes.pass_config import BasePassConfig, PassConfigParam
1621

@@ -119,6 +124,8 @@ def process_context_iterator(component_models, llm_pipeline, output_dir):
119124

120125
# save the model with fixed shapes
121126
component_model_path = output_dir / f"{new_component_name}.onnx"
127+
# Add olive version to metadata
128+
add_version_metadata_to_model_proto(component_proto)
122129
onnx.save_model(component_proto, component_model_path)
123130
new_groups[key][new_component_name] = ONNXModelHandler(
124131
model_path=output_dir, onnx_file_name=component_model_path.name

test/passes/onnx/test_common.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,13 @@
66
import onnx
77
import pytest
88

9+
from olive.common.utils import is_hardlink
910
from olive.passes.olive_pass import create_pass_from_dict
10-
from olive.passes.onnx.common import model_proto_to_olive_model, resave_model
11+
from olive.passes.onnx.common import (
12+
add_version_metadata_to_model_proto,
13+
model_proto_to_olive_model,
14+
resave_model,
15+
)
1116
from olive.passes.onnx.conversion import OnnxConversion
1217
from test.utils import ONNX_MODEL_PATH, get_hf_model
1318

@@ -47,4 +52,11 @@ def test_resave_model(has_external_data, tmp_path):
4752
assert resave_path.exists()
4853
if has_external_data:
4954
assert (resave_path.parent / "resave.onnx.data").exists()
50-
assert onnx.load(resave_path) == onnx.load(input_model.model_path)
55+
56+
input_model = onnx.load(input_model.model_path)
57+
resaved_model = onnx.load(resave_path)
58+
59+
if not is_hardlink(resave_path):
60+
input_model = add_version_metadata_to_model_proto(input_model)
61+
62+
assert resaved_model == input_model

0 commit comments

Comments
 (0)