Skip to content

Commit db6d12b

Browse files
authored
Enhance ONNX model creation by adding configurable opset version parameter (#1037)
1 parent 7467d22 commit db6d12b

2 files changed

Lines changed: 6 additions & 4 deletions

File tree

lib/model/nns/onnx/onnx_exporter.js

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,20 @@ import * as layers from './layer/index.js'
1010
export default class ONNXExporter {
1111
/**
1212
* Create onnx model proto.
13+
* @param {object} [config] object
14+
* @param {object} [config.opset] Config for operator set
15+
* @param {number} [config.opset.version] Version of operator set
1316
* @returns {onnx.ModelProto} Model proto
1417
*/
15-
static createONNXModel() {
18+
static createONNXModel({ opset } = {}) {
1619
const model = new onnx.ModelProto()
1720
model.setProducerName('ai-on-browser/data-analysis-models')
1821
model.setProducerVersion('0.24.0')
1922
model.setIrVersion(9)
2023

2124
const opsetImport = new onnx.OperatorSetIdProto()
2225
opsetImport.setDomain('')
23-
opsetImport.setVersion(19)
26+
opsetImport.setVersion(opset?.version ?? 19)
2427
model.addOpsetImport(opsetImport)
2528

2629
const graph = new onnx.GraphProto()

tests/lib/model/nns/onnx/layer/gelu.test.js

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,7 @@ describe('export', () => {
5252

5353
describe('opset version 20', () => {
5454
test.each(['x', ['x']])('input %j', input => {
55-
const model = ONNXExporter.createONNXModel()
56-
model.getOpsetImportList()[0].setVersion(20)
55+
const model = ONNXExporter.createONNXModel({ opset: { version: 20 } })
5756
gelu.export(model, { type: 'gelu', input })
5857
const nodes = model.getGraph().getNodeList()
5958
expect(nodes).toHaveLength(1)

0 commit comments

Comments
 (0)