Skip to content

Commit f149bac

Browse files
authored
Enhance ComputationalGraph and ONNXExporter to support numeric inputs (#947)
* Enhance ComputationalGraph and ONNXExporter to support numeric inputs * Refactor gemm and mean operators to simplify attribute handling * Fix test * Clarify parameter descriptions
1 parent 62539c8 commit f149bac

File tree

10 files changed

+184
-56
lines changed

10 files changed

+184
-56
lines changed

Diff for: lib/model/neuralnetwork.js

-18
Original file line numberDiff line numberDiff line change
@@ -44,25 +44,7 @@ export default class NeuralNetwork {
4444
if (loss) {
4545
layers.push({ type: loss })
4646
}
47-
const const_numbers = new Set()
48-
for (const l of layers) {
49-
if (l.input && Array.isArray(l.input)) {
50-
for (let i = 0; i < l.input.length; i++) {
51-
if (typeof l.input[i] === 'number') {
52-
const_numbers.add(l.input[i])
53-
l.input[i] = `__const_number_${l.input[i]}`
54-
}
55-
}
56-
}
57-
}
58-
if (const_numbers.size) {
59-
layers[0].input = []
60-
}
6147
const graph = new ComputationalGraph()
62-
for (const cn of const_numbers) {
63-
graph.add({ type: 'const', value: [[cn]] }, `__const_number_${cn}`, [])
64-
}
65-
6648
for (const l of layers) {
6749
graph.add(l, l.name, l.input)
6850
}

Diff for: lib/model/nns/graph.js

+32-15
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import Matrix from '../../util/matrix.js'
22
import { NeuralnetworkException } from '../neuralnetwork.js'
33
import Layer from './layer/base.js'
4-
import { InputLayer, OutputLayer } from './layer/index.js'
4+
import { InputLayer, OutputLayer, ConstLayer } from './layer/index.js'
55
import ONNXExporter from './onnx/onnx_exporter.js'
66
import ONNXImporter from './onnx/onnx_importer.js'
77

@@ -10,14 +10,14 @@ import ONNXImporter from './onnx/onnx_importer.js'
1010
* @typedef {import("./layer/index").PlainLayerObject} PlainLayerObject
1111
*/
1212
/**
13-
* @typedef {PlainLayerObject & {input?: string | string[], name?: string}} LayerObject
13+
* @typedef {PlainLayerObject & {input?: string | number | (string | number)[], name?: string}} LayerObject
1414
*/
1515

1616
class Node {
1717
/**
1818
* @param {string} name Name of this node
1919
* @param {PlainLayerObject} layer Layer
20-
* @param {string[]} input Input node names
20+
* @param {(string | number)[]} input Input node names
2121
* @param {{index: number, subscript: number | null}[]} parent Parent node informations
2222
* @param {ComputationalGraph} graph graph
2323
*/
@@ -37,6 +37,9 @@ class Node {
3737
const numberSubscriptRegexp = /\[([0-9]+)\]$/
3838
const stringSubscriptRegexp = /\[([a-zA-Z_][0-9a-zA-Z_]*)\]$/
3939
return (this._parent = this.input.map(p => {
40+
if (typeof p === 'number') {
41+
return new ConstLayer({ value: p })
42+
}
4043
const nm = p && p.match(numberSubscriptRegexp)
4144
const sm = p && p.match(stringSubscriptRegexp)
4245
const subscript = nm ? +nm[1] : sm ? sm[1] : null
@@ -194,14 +197,23 @@ export default class ComputationalGraph {
194197
*/
195198
toDot() {
196199
let s = 'digraph g {\n'
200+
const constNumbers = new Set()
197201
for (let i = 0; i < this._nodes.length; i++) {
198202
const node = this.nodes[i]
199203
const label = node.layer.constructor.name + (node.name ? `\\n${node.name}` : '')
200204
s += ` l${i} [label="${label}"];\n`
201205
for (const parent of node.parents) {
202-
s += ` l${parent.index} -> l${i};\n`
206+
if (parent instanceof ConstLayer) {
207+
s += ` c${parent._value} -> l${i};\n`
208+
constNumbers.add(parent._value)
209+
} else {
210+
s += ` l${parent.index} -> l${i};\n`
211+
}
203212
}
204213
}
214+
for (const cn of constNumbers) {
215+
s += ` c${cn} [label="Constant\\n${cn}"];\n`
216+
}
205217
return s + '}'
206218
}
207219

@@ -217,20 +229,20 @@ export default class ComputationalGraph {
217229
* Add a layer.
218230
* @param {Layer | PlainLayerObject} layer Added layer
219231
* @param {string} [name] Node name
220-
* @param {string[] | string} [inputs] Input node names for the added layer
232+
* @param {(string | number)[] | string | number} [inputs] Input node names or const value for the added layer
221233
*/
222234
add(layer, name, inputs = undefined) {
223235
this._order = null
224236
if (!(layer instanceof Layer)) {
225237
layer = Layer.fromObject(layer)
226238
}
227239
let parentinfos = null
228-
if (!inputs) {
240+
if (inputs == null) {
229241
parentinfos = []
230242
if (layer.calc.length > 0 && this._nodes.length > 0) {
231243
parentinfos.push({ index: this._nodes.length - 1, subscript: null })
232244
}
233-
} else if (typeof inputs === 'string') {
245+
} else if (!Array.isArray(inputs)) {
234246
inputs = [inputs]
235247
}
236248
layer.graph = this
@@ -257,11 +269,14 @@ export default class ComputationalGraph {
257269
}
258270
const s = []
259271
const outputList = Array.from(this._nodes, () => [])
260-
const addedParentCount = Array.from(this._nodes, n => n.parents.length)
272+
const addedParentCount = Array(this._nodes.length).fill(0)
261273
for (let i = 0; i < this._nodes.length; i++) {
262274
const node = this._nodes[i]
263275
for (const parent of node.parents) {
264-
outputList[parent.index].push(i)
276+
if (!(parent instanceof ConstLayer)) {
277+
outputList[parent.index].push(i)
278+
addedParentCount[i]++
279+
}
265280
}
266281
for (const name of node.layer.dependentLayers) {
267282
for (let k = 0; k < this._nodes.length; k++) {
@@ -309,11 +324,13 @@ export default class ComputationalGraph {
309324
const l = this._nodes[i]
310325
o[i] = l.layer.calc(
311326
...l.parents.map(p =>
312-
typeof p.subscript === 'number'
313-
? o[p.index][p.subscript]
314-
: p.subscript !== null
315-
? this._nodes[p.index].layer[p.subscript]
316-
: o[p.index]
327+
p instanceof ConstLayer
328+
? p.calc()
329+
: typeof p.subscript === 'number'
330+
? o[p.index][p.subscript]
331+
: p.subscript !== null
332+
? this._nodes[p.index].layer[p.subscript]
333+
: o[p.index]
317334
)
318335
)
319336
l.outputValue = o[i]
@@ -383,7 +400,7 @@ export default class ComputationalGraph {
383400
}
384401
}
385402
l.parents.forEach((p, k) => {
386-
if (!bo[k]) return
403+
if (p instanceof ConstLayer || !bo[k]) return
387404
const subidx = p.subscript || 0
388405
if (!bi[p.index][subidx]) {
389406
bi[p.index][subidx] = bo[k].copy()

Diff for: lib/model/nns/onnx/onnx_exporter.js

+18
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ export default class ONNXExporter {
5858
existNames.add(node.name)
5959
}
6060
}
61+
const constNumbers = new Set()
6162
for (let i = 0; i < ns.length; i++) {
6263
if (!ns[i].name) {
6364
const basename = `_${ns[i].type}`
@@ -73,6 +74,20 @@ export default class ONNXExporter {
7374
if (i > 0 && !ns[i].input) {
7475
ns[i].input = ns[i - 1].name
7576
}
77+
if (Array.isArray(ns[i].input)) {
78+
for (let k = 0; k < ns[i].input.length; k++) {
79+
if (typeof ns[i].input[k] === 'number') {
80+
constNumbers.add(ns[i].input[k])
81+
ns[i].input[k] = `__const_number_${ns[i].input[k]}`
82+
}
83+
}
84+
} else if (typeof ns[i].input === 'number') {
85+
constNumbers.add(ns[i].input)
86+
ns[i].input = `__const_number_${ns[i].input}`
87+
}
88+
}
89+
for (const cn of constNumbers) {
90+
ns.unshift({ type: 'const', value: cn, name: `__const_number_${cn}` })
7691
}
7792

7893
const outputInfo = {}
@@ -92,6 +107,9 @@ export default class ONNXExporter {
92107
let size = outputInfo[node.name].size
93108
for (let i = 1; i < inputs.length; i++) {
94109
const si = outputInfo[inputs[i]].size
110+
if (!si) {
111+
continue
112+
}
95113
const length = Math.max(si.length, size.length)
96114
size = Array.from({ length }, (_, i) => {
97115
const sa = size[size.length - length + i]

Diff for: lib/model/nns/onnx/operators/gemm.js

+2-4
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,9 @@ export default {
5555
weightName = inputList[1] + '_t'
5656
}
5757
if (attrs.alpha !== 1) {
58-
layers.push({ type: 'const', value: [attrs.alpha], name: inputList[1] + '_alpha' })
5958
layers.push({
6059
type: 'mult',
61-
input: [weightName, inputList[1] + '_alpha'],
60+
input: [weightName, attrs.alpha],
6261
name: inputList[1] + '_mul_a',
6362
})
6463
weightName = inputList[1] + '_mul_a'
@@ -67,10 +66,9 @@ export default {
6766
let biasName = inputList[2]
6867
if (biasName && !initializers.b) {
6968
if (attrs.beta !== 1) {
70-
layers.push({ type: 'const', value: [attrs.beta], name: inputList[2] + '_beta' })
7169
layers.push({
7270
type: 'mult',
73-
input: [biasName, inputList[2] + '_beta'],
71+
input: [biasName, attrs.beta],
7472
name: inputList[2] + '_mul_b',
7573
})
7674
biasName = inputList[2] + '_mul_b'

Diff for: lib/model/nns/onnx/operators/mean.js

+1-2
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,9 @@ export default {
1616
const outputName = node.getOutputList()[0]
1717
return [
1818
{ type: 'add', input: node.getInputList(), name: outputName + '_sum' },
19-
{ type: 'const', input: [], name: outputName + '_den', value: node.getInputList().length },
2019
{
2120
type: 'div',
22-
input: [outputName + '_sum', outputName + '_den'],
21+
input: [outputName + '_sum', node.getInputList().length],
2322
name: outputName,
2423
},
2524
]

Diff for: tests/lib/model/neuralnetwork.test.js

+5-6
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ describe('neuralnetwork', () => {
1818
})
1919

2020
test.each([
21+
[null, 'SGDOptimizer'],
2122
['sgd', 'SGDOptimizer'],
2223
['adam', 'AdamOptimizer'],
2324
['momentum', 'MomentumOptimizer'],
@@ -93,12 +94,10 @@ describe('neuralnetwork', () => {
9394
{ type: 'add', input: [1, 'in'] },
9495
])
9596

96-
expect(net._graph.nodes).toHaveLength(4)
97-
expect(net._graph.nodes[0].layer.constructor.name).toBe('ConstLayer')
98-
expect(net._graph.nodes[0].layer._value).toEqual([[1]])
99-
expect(net._graph.nodes[1].layer.constructor.name).toBe('InputLayer')
100-
expect(net._graph.nodes[2].layer.constructor.name).toBe('AddLayer')
101-
expect(net._graph.nodes[3].layer.constructor.name).toBe('OutputLayer')
97+
expect(net._graph.nodes).toHaveLength(3)
98+
expect(net._graph.nodes[0].layer.constructor.name).toBe('InputLayer')
99+
expect(net._graph.nodes[1].layer.constructor.name).toBe('AddLayer')
100+
expect(net._graph.nodes[2].layer.constructor.name).toBe('OutputLayer')
102101

103102
const x = [
104103
[1, 2],

Diff for: tests/lib/model/nns/graph.test.js

+77-2
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,9 @@ describe('Computational Graph', () => {
162162
const graph = new ComputationalGraph()
163163
graph.add(Layer.fromObject({ type: 'input' }))
164164
graph.add(Layer.fromObject({ type: 'tanh' }), 't')
165+
graph.add(Layer.fromObject({ type: 'add' }), undefined, ['t', 2])
165166
expect(graph.toDot()).toBe(
166-
'digraph g {\n l0 [label="InputLayer"];\n l1 [label="TanhLayer\\nt"];\n l0 -> l1;\n}'
167+
'digraph g {\n l0 [label="InputLayer"];\n l1 [label="TanhLayer\\nt"];\n l0 -> l1;\n l2 [label="AddLayer"];\n l1 -> l2;\n c2 -> l2;\n c2 [label="Constant\\n2"];\n}'
167168
)
168169
})
169170

@@ -318,6 +319,25 @@ describe('Computational Graph', () => {
318319
expect(graph.nodes[0].lastOutputSize).toEqual([100, 4])
319320
})
320321

322+
test('number input', () => {
323+
const graph = new ComputationalGraph()
324+
graph.add(Layer.fromObject({ type: 'input' }), 'in')
325+
graph.add(Layer.fromObject({ type: 'tanh' }), 'tanh', 2)
326+
graph.add(Layer.fromObject({ type: 'add' }), undefined, ['in', 'tanh'])
327+
328+
expect(graph.nodes[1].parents).toHaveLength(1)
329+
expect(graph.nodes[1].parents[0].constructor.name).toBe('ConstLayer')
330+
})
331+
332+
test('number array input', () => {
333+
const graph = new ComputationalGraph()
334+
graph.add(Layer.fromObject({ type: 'input' }), 'in')
335+
graph.add(Layer.fromObject({ type: 'add' }), undefined, ['in', 2])
336+
337+
expect(graph.nodes[1].parents).toHaveLength(2)
338+
expect(graph.nodes[1].parents[1].constructor.name).toBe('ConstLayer')
339+
})
340+
321341
test('invalid input name', () => {
322342
const graph = new ComputationalGraph()
323343
graph.add(Layer.fromObject({ type: 'input' }), 'in0')
@@ -371,6 +391,41 @@ describe('Computational Graph', () => {
371391
}
372392
})
373393

394+
test('add with number input', () => {
395+
const graph = new ComputationalGraph()
396+
graph.add(Layer.fromObject({ type: 'input' }), 'in')
397+
graph.add(Layer.fromObject({ type: 'tanh' }), 'tanh', 2)
398+
graph.add(Layer.fromObject({ type: 'add' }), 'op', ['in', 'tanh'])
399+
400+
const x = Matrix.randn(100, 3)
401+
graph.bind({ input: x })
402+
graph.calc()
403+
const y = graph.nodes[2].outputValue
404+
expect(y.sizes).toEqual([100, 3])
405+
for (let i = 0; i < x.rows; i++) {
406+
for (let j = 0; j < x.cols; j++) {
407+
expect(y.at(i, j)).toBe(x.at(i, j) + Math.tanh(2))
408+
}
409+
}
410+
})
411+
412+
test('add with number array input', () => {
413+
const graph = new ComputationalGraph()
414+
graph.add(Layer.fromObject({ type: 'input' }), 'in')
415+
graph.add(Layer.fromObject({ type: 'add' }), 'op', ['in', 2])
416+
417+
const x = Matrix.randn(100, 3)
418+
graph.bind({ input: x })
419+
graph.calc()
420+
const y = graph.nodes[1].outputValue
421+
expect(y.sizes).toEqual([100, 3])
422+
for (let i = 0; i < x.rows; i++) {
423+
for (let j = 0; j < x.cols; j++) {
424+
expect(y.at(i, j)).toBe(x.at(i, j) + 2)
425+
}
426+
}
427+
})
428+
374429
test('require', () => {
375430
const graph = new ComputationalGraph()
376431
graph.add(Layer.fromObject({ type: 'input' }), 'l0')
@@ -541,6 +596,25 @@ describe('Computational Graph', () => {
541596
}
542597
})
543598

599+
test('mult with number input', () => {
600+
const graph = new ComputationalGraph()
601+
graph.add(Layer.fromObject({ type: 'input' }), 'in')
602+
graph.add(Layer.fromObject({ type: 'mult' }), 'op', ['in', 2])
603+
graph.add(Layer.fromObject({ type: 'output' }))
604+
605+
const x = Matrix.randn(100, 3)
606+
graph.bind({ input: x })
607+
graph.calc()
608+
graph.grad(Matrix.ones(100, 3))
609+
const g = graph.nodes[0].gradientValue[0]
610+
expect(g.sizes).toEqual([100, 3])
611+
for (let i = 0; i < x.rows; i++) {
612+
for (let j = 0; j < x.cols; j++) {
613+
expect(g.at(i, j)).toBe(2)
614+
}
615+
}
616+
})
617+
544618
test.each([0, 1])('subscript input %d', i => {
545619
const graph = new ComputationalGraph()
546620
graph.add(Layer.fromObject({ type: 'input' }))
@@ -594,8 +668,9 @@ describe('Computational Graph', () => {
594668
test('layer after output without grad', () => {
595669
const graph = new ComputationalGraph()
596670
graph.add(Layer.fromObject({ type: 'input' }))
597-
graph.add(Layer.fromObject({ type: 'output' }))
671+
graph.add(Layer.fromObject({ type: 'output' }), 'out')
598672
graph.add(Layer.fromObject({ type: 'tanh' }))
673+
graph.add(Layer.fromObject({ type: 'tanh' }), undefined, ['out'])
599674

600675
const x = Matrix.randn(100, 3)
601676
graph.bind({ input: x })

0 commit comments

Comments
 (0)