Skip to content

Commit b4d4617

Browse files
authored
Add as constant layer if input is initializer (#934)
1 parent 8dd5bba commit b4d4617

File tree

11 files changed

+60
-53
lines changed

11 files changed

+60
-53
lines changed

lib/model/nns/onnx/onnx_importer.js

+17
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import input from './operators/input.js'
55
import output from './operators/output.js'
66

77
import * as operators from './operators/index.js'
8+
import { loadTensor } from './utils.js'
89

910
/**
1011
* ONNX importer
@@ -41,6 +42,22 @@ export default class ONNXImporter {
4142
for (const node of graph.getOutputList()) {
4243
nodes.push(...output.import(model, node))
4344
}
45+
46+
const importedNames = new Set(nodes.map(n => n.name))
47+
const inputNames = new Map(graph.getInitializerList().map(init => [init.getName(), init]))
48+
for (const node of nodes.filter(n => n.input)) {
49+
const inputs = Array.isArray(node.input) ? node.input : [node.input]
50+
for (const i of inputs) {
51+
if (importedNames.has(i)) {
52+
continue
53+
}
54+
const initializer = inputNames.get(i)
55+
if (initializer) {
56+
nodes.push({ type: 'const', name: i, value: loadTensor(initializer) })
57+
importedNames.add(i)
58+
}
59+
}
60+
}
4461
return nodes
4562
}
4663
}

lib/model/nns/onnx/operators/add.js

+1-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import { onnx } from '../onnx_importer.js'
2-
import { requireTensor } from '../utils.js'
32

43
/**
54
* Handle add operator
@@ -14,9 +13,6 @@ export default {
1413
* @returns {object[]} Objects represented a layer
1514
*/
1615
import(model, node) {
17-
return [
18-
...requireTensor(model, node.getInputList()),
19-
{ type: 'add', input: node.getInputList(), name: node.getOutputList()[0] },
20-
]
16+
return [{ type: 'add', input: node.getInputList(), name: node.getOutputList()[0] }]
2117
},
2218
}

lib/model/nns/onnx/operators/div.js

+1-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import { onnx } from '../onnx_importer.js'
2-
import { requireTensor } from '../utils.js'
32

43
/**
54
* Handle div operator
@@ -14,9 +13,6 @@ export default {
1413
* @returns {object[]} Objects represented a layer
1514
*/
1615
import(model, node) {
17-
return [
18-
...requireTensor(model, node.getInputList()),
19-
{ type: 'div', input: node.getInputList(), name: node.getOutputList()[0] },
20-
]
16+
return [{ type: 'div', input: node.getInputList(), name: node.getOutputList()[0] }]
2117
},
2218
}

lib/model/nns/onnx/operators/matmul.js

+1-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import { onnx } from '../onnx_importer.js'
2-
import { requireTensor } from '../utils.js'
32

43
/**
54
* Handle matmul operator
@@ -14,9 +13,6 @@ export default {
1413
* @returns {object[]} Objects represented a layer
1514
*/
1615
import(model, node) {
17-
return [
18-
...requireTensor(model, node.getInputList()),
19-
{ type: 'matmul', input: node.getInputList(), name: node.getOutputList()[0] },
20-
]
16+
return [{ type: 'matmul', input: node.getInputList(), name: node.getOutputList()[0] }]
2117
},
2218
}

lib/model/nns/onnx/operators/max.js

+1-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import { onnx } from '../onnx_importer.js'
2-
import { requireTensor } from '../utils.js'
32

43
/**
54
* Handle max operator
@@ -14,9 +13,6 @@ export default {
1413
* @returns {object[]} Objects represented a layer
1514
*/
1615
import(model, node) {
17-
return [
18-
...requireTensor(model, node.getInputList()),
19-
{ type: 'max', input: node.getInputList(), name: node.getOutputList()[0] },
20-
]
16+
return [{ type: 'max', input: node.getInputList(), name: node.getOutputList()[0] }]
2117
},
2218
}

lib/model/nns/onnx/operators/min.js

+1-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import { onnx } from '../onnx_importer.js'
2-
import { requireTensor } from '../utils.js'
32

43
/**
54
* Handle min operator
@@ -14,9 +13,6 @@ export default {
1413
* @returns {object[]} Objects represented a layer
1514
*/
1615
import(model, node) {
17-
return [
18-
...requireTensor(model, node.getInputList()),
19-
{ type: 'min', input: node.getInputList(), name: node.getOutputList()[0] },
20-
]
16+
return [{ type: 'min', input: node.getInputList(), name: node.getOutputList()[0] }]
2117
},
2218
}

lib/model/nns/onnx/operators/mul.js

+1-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import { onnx } from '../onnx_importer.js'
2-
import { requireTensor } from '../utils.js'
32

43
/**
54
* Handle mul operator
@@ -14,9 +13,6 @@ export default {
1413
* @returns {object[]} Objects represented a layer
1514
*/
1615
import(model, node) {
17-
return [
18-
...requireTensor(model, node.getInputList()),
19-
{ type: 'mult', input: node.getInputList(), name: node.getOutputList()[0] },
20-
]
16+
return [{ type: 'mult', input: node.getInputList(), name: node.getOutputList()[0] }]
2117
},
2218
}

lib/model/nns/onnx/operators/reshape.js

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { onnx } from '../onnx_importer.js'
2-
import { loadTensor, loadAttribute, requireTensor } from '../utils.js'
2+
import { loadTensor, loadAttribute } from '../utils.js'
33

44
/**
55
* Handle reshape operator
@@ -31,7 +31,6 @@ export default {
3131
throw new Error(`Invalid shape value ${JSON.stringify(initializers.shape)}.`)
3232
}
3333
return [
34-
...requireTensor(model, inputList[0]),
3534
{
3635
type: 'reshape',
3736
input: [inputList[0]],

lib/model/nns/onnx/operators/sub.js

+1-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import { onnx } from '../onnx_importer.js'
2-
import { requireTensor } from '../utils.js'
32

43
/**
54
* Handle sub operator
@@ -14,9 +13,6 @@ export default {
1413
* @returns {object[]} Objects represented a layer
1514
*/
1615
import(model, node) {
17-
return [
18-
...requireTensor(model, node.getInputList()),
19-
{ type: 'sub', input: node.getInputList(), name: node.getOutputList()[0] },
20-
]
16+
return [{ type: 'sub', input: node.getInputList(), name: node.getOutputList()[0] }]
2117
},
2218
}

lib/model/nns/onnx/utils.js

-16
Original file line numberDiff line numberDiff line change
@@ -118,22 +118,6 @@ export const loadAttribute = attribute => {
118118
throw new Error('Not implemented attribute type.')
119119
}
120120

121-
/**
122-
* Create const layers from initializer list.
123-
* @param {onnx.ModelProto} model Model object
124-
* @param {string[]} names Input name
125-
* @returns {object[]} Require layer objects
126-
*/
127-
export const requireTensor = (model, names) => {
128-
const layers = []
129-
for (const initializer of model.getGraph().getInitializerList()) {
130-
if (names.includes(initializer.getName())) {
131-
layers.push({ type: 'const', value: loadTensor(initializer), name: initializer.getName() })
132-
}
133-
}
134-
return layers
135-
}
136-
137121
/**
138122
* Create const node if needed and return const node name.
139123
* @param {onnx.ModelProto} model Model object

tests/lib/model/nns/graph.test.js

+35
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import Matrix from '../../../../lib/util/matrix.js'
1111
import ComputationalGraph from '../../../../lib/model/nns/graph.js'
1212

1313
import Layer from '../../../../lib/model/nns/layer/base.js'
14+
import Tensor from '../../../../lib/util/tensor.js'
1415

1516
describe('Computational Graph', () => {
1617
test('constructor', () => {
@@ -193,6 +194,40 @@ describe('Computational Graph', () => {
193194
expect(y[i]).toBeCloseTo(Math.tanh(x.value[i]))
194195
}
195196
})
197+
198+
test('complex layers', async () => {
199+
const graph = new ComputationalGraph()
200+
graph.add(Layer.fromObject({ type: 'input', size: [null, 6, 6, 3] }))
201+
graph.add(Layer.fromObject({ type: 'conv', kernel: 3 }))
202+
graph.add(Layer.fromObject({ type: 'max_pool', kernel: 2 }))
203+
graph.add(Layer.fromObject({ type: 'relu' }))
204+
graph.add(Layer.fromObject({ type: 'flatten' }))
205+
graph.add(Layer.fromObject({ type: 'full', out_size: 10 }))
206+
graph.add(Layer.fromObject({ type: 'tanh' }), 'v')
207+
graph.add(Layer.fromObject({ type: 'pau' }))
208+
graph.add(Layer.fromObject({ type: 'tanh' }), 'pau')
209+
graph.add(Layer.fromObject({ type: 'apl' }), 'apl', 'v')
210+
graph.add(Layer.fromObject({ type: 'add' }), null, ['pau', 'apl'])
211+
graph.add(Layer.fromObject({ type: 'output' }))
212+
213+
const x = Tensor.randn([100, 6, 6, 3])
214+
graph.bind({ input: x })
215+
graph.calc()
216+
const t = graph.outputNodes[0].outputValue
217+
218+
const buf = await graph.toONNX()
219+
session = await ort.InferenceSession.create(buf)
220+
221+
const xten = new ort.Tensor('float32', x.value, x.sizes)
222+
const out = await session.run({ _input: xten })
223+
const yten = out._add
224+
expect(yten.dims).toEqual([100, 10])
225+
const y = await yten.getData(true)
226+
227+
for (let i = 0; i < y.length; i++) {
228+
expect(y[i]).toBeCloseTo(t.value[i])
229+
}
230+
})
196231
})
197232

198233
describe('add', () => {

0 commit comments

Comments
 (0)