Skip to content

Commit ea223e4

Browse files
authored
Add ONNX export (#928)
* Add ONNX export * Fix some issues * Fix test * set env.wasm.numThreads to 1
1 parent 827aa3c commit ea223e4

File tree

305 files changed

+16904
-8
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

305 files changed

+16904
-8
lines changed

create_import_list.js

+13
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,19 @@ const createONNXOperatorlist = async () => {
223223
await fs.promises.writeFile(operatorsDir + '/index.js', '// This file is generated automatically.\n' + code)
224224
}
225225

226+
const createONNXLayerlist = async () => {
227+
const layerDir = './lib/model/nns/onnx/layer'
228+
const files = await fs.promises.readdir(layerDir)
229+
let code = ''
230+
for (const file of files) {
231+
if (file !== 'index.js' && file.endsWith('.js')) {
232+
code += `export { default as ${file.slice(0, -3)} } from './${file}'\n`
233+
}
234+
}
235+
await fs.promises.writeFile(layerDir + '/index.js', '// This file is generated automatically.\n' + code)
236+
}
237+
226238
await createLayerlist()
239+
await createONNXLayerlist()
227240
await createONNXOperatorlist()
228241
await createEntrypoint()

lib/model/nns/graph.js

+9
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import Matrix from '../../util/matrix.js'
22
import { NeuralnetworkException } from '../neuralnetwork.js'
33
import Layer from './layer/base.js'
44
import { InputLayer, OutputLayer } from './layer/index.js'
5+
import ONNXExporter from './onnx/onnx_exporter.js'
56
import ONNXImporter from './onnx/onnx_importer.js'
67

78
/**
@@ -130,6 +131,14 @@ export default class ComputationalGraph {
130131
return s + '}'
131132
}
132133

134+
/**
135+
* Returns onnx model
136+
* @returns {Uint8Array} onnx model byte array
137+
*/
138+
toONNX() {
139+
return ONNXExporter.dump(this.toObject())
140+
}
141+
133142
/**
134143
* Add a layer.
135144
* @param {Layer | PlainLayerObject} layer Added layer

lib/model/nns/onnx/layer/abs.js

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import { onnx } from '../onnx_exporter.js'
2+
3+
/**
4+
* Handle abs layer
5+
*/
6+
export default {
7+
/**
8+
* Export to onnx object.
9+
* @param {onnx.ModelProto} model Model object
10+
* @param {import("../../graph").LayerObject & {type: 'abs'}} obj Node object
11+
*/
12+
export(model, obj) {
13+
const input = Array.isArray(obj.input) ? obj.input[0] : obj.input
14+
const node = new onnx.NodeProto()
15+
node.setOpType('Abs')
16+
node.addInput(input)
17+
node.addOutput(obj.name)
18+
19+
const graph = model.getGraph()
20+
graph.addNode(node)
21+
},
22+
}

lib/model/nns/onnx/layer/acos.js

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import { onnx } from '../onnx_exporter.js'
2+
3+
/**
4+
* Handle acos layer
5+
*/
6+
export default {
7+
/**
8+
* Export to onnx object.
9+
* @param {onnx.ModelProto} model Model object
10+
* @param {import("../../graph").LayerObject & {type: 'acos'}} obj Node object
11+
*/
12+
export(model, obj) {
13+
const input = Array.isArray(obj.input) ? obj.input[0] : obj.input
14+
const node = new onnx.NodeProto()
15+
node.setOpType('Acos')
16+
node.addInput(input)
17+
node.addOutput(obj.name)
18+
19+
const graph = model.getGraph()
20+
graph.addNode(node)
21+
},
22+
}

lib/model/nns/onnx/layer/acosh.js

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import { onnx } from '../onnx_exporter.js'
2+
3+
/**
4+
* Handle acosh layer
5+
*/
6+
export default {
7+
/**
8+
* Export to onnx object.
9+
* @param {onnx.ModelProto} model Model object
10+
* @param {import("../../graph").LayerObject & {type: 'acosh'}} obj Node object
11+
*/
12+
export(model, obj) {
13+
const input = Array.isArray(obj.input) ? obj.input[0] : obj.input
14+
const node = new onnx.NodeProto()
15+
node.setOpType('Acosh')
16+
node.addInput(input)
17+
node.addOutput(obj.name)
18+
19+
const graph = model.getGraph()
20+
graph.addNode(node)
21+
},
22+
}

lib/model/nns/onnx/layer/add.js

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import { onnx } from '../onnx_exporter.js'
2+
3+
/**
4+
* Handle add layer
5+
*/
6+
export default {
7+
/**
8+
* Export to onnx object.
9+
* @param {onnx.ModelProto} model Model object
10+
* @param {import("../../graph").LayerObject & {type: 'add'}} obj Node object
11+
*/
12+
export(model, obj) {
13+
if (!Array.isArray(obj.input)) {
14+
throw new Error(`Invalid attribute 'input' value ${obj.input}.`)
15+
}
16+
const node = new onnx.NodeProto()
17+
node.setOpType(obj.input.length === 2 ? 'Add' : 'Sum')
18+
for (const i of obj.input) {
19+
node.addInput(i)
20+
}
21+
node.addOutput(obj.name)
22+
23+
const graph = model.getGraph()
24+
graph.addNode(node)
25+
},
26+
}

lib/model/nns/onnx/layer/and.js

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import { onnx } from '../onnx_exporter.js'
2+
3+
/**
4+
* Handle and layer
5+
*/
6+
export default {
7+
/**
8+
* Export to onnx object.
9+
* @param {onnx.ModelProto} model Model object
10+
* @param {import("../../graph").LayerObject & {type: 'and'}} obj Node object
11+
* @param {{[key: string]: {type: onnx.TensorProto.DataType; size: number[]}}} info Output informatino of other layers
12+
* @returns {{type: onnx.TensorProto.DataType; size: number[]} | undefined} Output information of this layer
13+
*/
14+
export(model, obj, info) {
15+
if (!Array.isArray(obj.input)) {
16+
throw new Error(`Invalid attribute 'input' value ${obj.input}.`)
17+
}
18+
const graph = model.getGraph()
19+
20+
const node = new onnx.NodeProto()
21+
if (obj.input.length === 1) {
22+
node.setOpType('Identity')
23+
node.addInput(obj.input[0])
24+
node.addOutput(obj.name)
25+
graph.addNode(node)
26+
return
27+
}
28+
const boolInputs = []
29+
for (const i of obj.input) {
30+
if (info[i].type === onnx.TensorProto.DataType.BOOL) {
31+
boolInputs.push(i)
32+
} else {
33+
const castnode = new onnx.NodeProto()
34+
castnode.setOpType('Cast')
35+
castnode.addInput(i)
36+
castnode.addOutput(`${obj.name}_${i}_cast`)
37+
const to = new onnx.AttributeProto()
38+
to.setName('to')
39+
to.setType(onnx.AttributeProto.AttributeType.INT)
40+
to.setI(onnx.TensorProto.DataType.BOOL)
41+
castnode.addAttribute(to)
42+
graph.addNode(castnode)
43+
boolInputs.push(`${obj.name}_${i}_cast`)
44+
}
45+
}
46+
let prev_in = boolInputs[0]
47+
for (let i = 1; i < boolInputs.length - 1; i++) {
48+
const node_and = new onnx.NodeProto()
49+
node_and.setOpType('And')
50+
node_and.addInput(prev_in)
51+
node_and.addInput(boolInputs[i])
52+
node_and.addOutput((prev_in = obj.name + `_and_${i - 1}`))
53+
graph.addNode(node_and)
54+
}
55+
56+
node.setOpType('And')
57+
node.addInput(prev_in)
58+
node.addInput(boolInputs.at(-1))
59+
60+
node.addOutput(obj.name)
61+
graph.addNode(node)
62+
return { type: onnx.TensorProto.DataType.BOOL }
63+
},
64+
}

lib/model/nns/onnx/layer/apl.js

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import { onnx } from '../onnx_exporter.js'
2+
3+
/**
4+
* Handle apl layer
5+
*/
6+
export default {
7+
/**
8+
* Export to onnx object.
9+
* @param {onnx.ModelProto} model Model object
10+
* @param {import("../../graph").LayerObject & {type: 'apl'}} obj Node object
11+
*/
12+
export(model, obj) {
13+
const input = Array.isArray(obj.input) ? obj.input[0] : obj.input
14+
const node_base = new onnx.NodeProto()
15+
node_base.setOpType('Relu')
16+
node_base.addInput(input)
17+
node_base.addOutput(obj.name + '_base')
18+
19+
const graph = model.getGraph()
20+
graph.addNode(node_base)
21+
22+
const node_add = new onnx.NodeProto()
23+
node_add.setOpType('Sum')
24+
node_add.addInput(obj.name + '_base')
25+
26+
const s = obj.s ?? 2
27+
const a = Array.isArray(obj.a) ? obj.a : Array(s).fill(obj.a ?? 0.1)
28+
const b = Array.isArray(obj.b) ? obj.b : Array(s).fill(obj.b ?? 0)
29+
for (let i = 0; i < s; i++) {
30+
const tensor_a = new onnx.TensorProto()
31+
tensor_a.setName(obj.name + '_a_' + i)
32+
tensor_a.setDataType(onnx.TensorProto.DataType.FLOAT)
33+
tensor_a.setDimsList([1])
34+
tensor_a.setFloatDataList([a[i]])
35+
36+
const tensor_b = new onnx.TensorProto()
37+
tensor_b.setName(obj.name + '_b_' + i)
38+
tensor_b.setDataType(onnx.TensorProto.DataType.FLOAT)
39+
tensor_b.setDimsList([1])
40+
tensor_b.setFloatDataList([b[i]])
41+
42+
const node_sub = new onnx.NodeProto()
43+
node_sub.setOpType('Sub')
44+
node_sub.addInput(obj.name + '_b_' + i)
45+
node_sub.addInput(input)
46+
node_sub.addOutput(obj.name + '_sub_' + i)
47+
48+
const node_relu = new onnx.NodeProto()
49+
node_relu.setOpType('Relu')
50+
node_relu.addInput(obj.name + '_sub_' + i)
51+
node_relu.addOutput(obj.name + '_relu_' + i)
52+
53+
const node_mult = new onnx.NodeProto()
54+
node_mult.setOpType('Mul')
55+
node_mult.addInput(obj.name + '_relu_' + i)
56+
node_mult.addInput(obj.name + '_a_' + i)
57+
node_mult.addOutput(obj.name + '_mul_' + i)
58+
59+
node_add.addInput(obj.name + '_mul_' + i)
60+
61+
graph.addInitializer(tensor_a)
62+
graph.addInitializer(tensor_b)
63+
graph.addNode(node_sub)
64+
graph.addNode(node_relu)
65+
graph.addNode(node_mult)
66+
}
67+
68+
node_add.addOutput(obj.name)
69+
graph.addNode(node_add)
70+
},
71+
}

lib/model/nns/onnx/layer/aranda.js

+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import { onnx } from '../onnx_exporter.js'
2+
import { getConstNodeName } from '../utils.js'
3+
4+
/**
5+
* Handle aranda layer
6+
*/
7+
export default {
8+
/**
9+
* Export to onnx object.
10+
* @param {onnx.ModelProto} model Model object
11+
* @param {import("../../graph").LayerObject & {type: 'aranda'}} obj Node object
12+
*/
13+
export(model, obj) {
14+
const input = Array.isArray(obj.input) ? obj.input[0] : obj.input
15+
const node_exp = new onnx.NodeProto()
16+
node_exp.setOpType('Exp')
17+
node_exp.addInput(input)
18+
node_exp.addOutput(obj.name + '_exp')
19+
20+
const tensor_l = new onnx.TensorProto()
21+
tensor_l.setName(obj.name + '_l')
22+
tensor_l.setDataType(onnx.TensorProto.DataType.FLOAT)
23+
tensor_l.setDimsList([1])
24+
tensor_l.setFloatDataList([obj.l ?? 2])
25+
26+
const tensor1 = getConstNodeName(model, 1)
27+
28+
const node_mult = new onnx.NodeProto()
29+
node_mult.setOpType('Mul')
30+
node_mult.addInput(obj.name + '_exp')
31+
node_mult.addInput(obj.name + '_l')
32+
node_mult.addOutput(obj.name + '_mult')
33+
34+
const node_add = new onnx.NodeProto()
35+
node_add.setOpType('Add')
36+
node_add.addInput(obj.name + '_mult')
37+
node_add.addInput(tensor1)
38+
node_add.addOutput(obj.name + '_add')
39+
40+
const node_reciprocal = new onnx.NodeProto()
41+
node_reciprocal.setOpType('Reciprocal')
42+
node_reciprocal.addInput(obj.name + '_add')
43+
node_reciprocal.addOutput(obj.name + '_reciprocal')
44+
45+
const node_reciprocal_pow = new onnx.NodeProto()
46+
node_reciprocal_pow.setOpType('Reciprocal')
47+
node_reciprocal_pow.addInput(obj.name + '_l')
48+
node_reciprocal_pow.addOutput(obj.name + '_reciprocal_pow')
49+
50+
const node_pow = new onnx.NodeProto()
51+
node_pow.setOpType('Pow')
52+
node_pow.addInput(obj.name + '_reciprocal')
53+
node_pow.addInput(obj.name + '_reciprocal_pow')
54+
node_pow.addOutput(obj.name + '_pow')
55+
56+
const node_sub = new onnx.NodeProto()
57+
node_sub.setOpType('Sub')
58+
node_sub.addInput(tensor1)
59+
node_sub.addInput(obj.name + '_pow')
60+
node_sub.addOutput(obj.name)
61+
62+
const graph = model.getGraph()
63+
graph.addInitializer(tensor_l)
64+
graph.addNode(node_exp)
65+
graph.addNode(node_mult)
66+
graph.addNode(node_add)
67+
graph.addNode(node_reciprocal)
68+
graph.addNode(node_reciprocal_pow)
69+
graph.addNode(node_pow)
70+
graph.addNode(node_sub)
71+
},
72+
}

lib/model/nns/onnx/layer/argmax.js

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import { onnx } from '../onnx_exporter.js'
2+
3+
/**
4+
* Handle argmax layer
5+
*/
6+
export default {
7+
/**
8+
* Export to onnx object.
9+
* @param {onnx.ModelProto} model Model object
10+
* @param {import("../../graph").LayerObject & {type: 'argmax'}} obj Node object
11+
* @param {{[key: string]: {type: onnx.TensorProto.DataType; size: number[]}}} info Output informatino of other layers
12+
* @returns {{type: onnx.TensorProto.DataType; size: number[]}} Output information of this layer
13+
*/
14+
export(model, obj, info) {
15+
const graph = model.getGraph()
16+
17+
const input = Array.isArray(obj.input) ? obj.input[0] : obj.input
18+
const node = new onnx.NodeProto()
19+
node.setOpType('ArgMax')
20+
node.addInput(input)
21+
node.addOutput(obj.name)
22+
23+
const axis = new onnx.AttributeProto()
24+
axis.setName('axis')
25+
axis.setType(onnx.AttributeProto.AttributeType.INT)
26+
axis.setI(obj.axis ?? -1)
27+
node.addAttribute(axis)
28+
const keepdims = new onnx.AttributeProto()
29+
keepdims.setName('keepdims')
30+
keepdims.setType(onnx.AttributeProto.AttributeType.INT)
31+
keepdims.setI(obj.keepdims ?? true ? 1 : 0)
32+
node.addAttribute(keepdims)
33+
34+
graph.addNode(node)
35+
36+
const size = info[input].size.concat()
37+
const targetAxis = axis.getI() < 0 ? axis.getI() + size.length : axis.getI()
38+
if (obj.keepdims ?? true) {
39+
size[targetAxis] = 1
40+
} else {
41+
size.splice(targetAxis, 1)
42+
}
43+
return { type: onnx.TensorProto.DataType.INT64, size }
44+
},
45+
}

0 commit comments

Comments
 (0)