Skip to content

Commit a3775ee

Browse files
hamelphiThe fusion_surrogates Authors
authored andcommitted
Add qlknn_7_11 onnx model export
PiperOrigin-RevId: 797791491
1 parent 56e5b6d commit a3775ee

File tree

4 files changed

+73
-0
lines changed

4 files changed

+73
-0
lines changed
295 KB
Binary file not shown.
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright 2025 DeepMind Technologies Limited.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for the ONNX QLKNN_7_11 model."""
16+
17+
from absl.testing import absltest
18+
import chex
19+
from fusion_surrogates.qlknn import qlknn_model
20+
from fusion_surrogates.qlknn.models import registry
21+
from jaxonnxruntime import backend
22+
import numpy as np
23+
import onnx
24+
25+
26+
class Qlknn711OnnxTest(absltest.TestCase):
27+
28+
def test_qlknn_7_11_onnx_model(self):
29+
"""Tests that the ONNX models outputs match jax model outputs."""
30+
with open(
31+
registry.ONNX_MODELS["qlknn_7_11_v1"], "rb"
32+
) as f:
33+
onnx_model = onnx.load(f.name)
34+
35+
jax_model = qlknn_model.QLKNNModel.load_model_from_name("qlknn_7_11_v1")
36+
37+
batch_size = 100
38+
test_input = np.random.randn(batch_size, jax_model.num_inputs).astype(
39+
np.float32
40+
)
41+
42+
# Running the ONNX model using jaxonnxruntime.
43+
jax_model_from_onnx = backend.prepare(onnx_model)
44+
onnx_flat_output = jax_model_from_onnx.run([test_input])
45+
46+
# Recovering the flux names from the ONNX graph
47+
output_names = [node.name for node in onnx_model.graph.output]
48+
49+
# Reconstructing the flux dictionary.
50+
onnx_dict_output = dict(
51+
(k, v) for k, v in zip(output_names, onnx_flat_output)
52+
)
53+
54+
# Running the original JAX model.
55+
jax_output = jax_model.predict(test_input)
56+
57+
# Checking that the output names match the expected output keys.
58+
self.assertEmpty(set(onnx_dict_output.keys()) ^ set(jax_output.keys()))
59+
# Checking that the output values match.
60+
chex.assert_trees_all_close(onnx_dict_output, jax_output, atol=1e-06)
61+
62+
63+
if __name__ == "__main__":
64+
absltest.main()

fusion_surrogates/qlknn/models/registry.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,7 @@
2121
MODELS = immutabledict.immutabledict({
2222
'qlknn_7_11_v1': f'{pathlib.Path(__file__).parent}/qlknn_7_11.qlknn',
2323
})
24+
25+
ONNX_MODELS = immutabledict.immutabledict({
26+
'qlknn_7_11_v1': f'{pathlib.Path(__file__).parent}/qlknn_7_11.onnx',
27+
})

pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,15 @@ dependencies = [
2929

3030
[project.optional-dependencies]
3131
testing = [
32+
"chex",
33+
"jaxonnxruntime",
34+
"onnx",
3235
"pytest",
3336
"pytest-xdist",
3437
"pylint>=2.6.0",
3538
"pyink",
39+
"pyyaml",
40+
"torch",
3641
]
3742
tglfnnukaea = [
3843
"torch",

0 commit comments

Comments
 (0)