Skip to content

Commit 8dd5bba

Browse files
authored
Sort and calculate according to dependencies (#933)
* Sort and calculate according to dependencies * Remove unnecessary imports
1 parent 109c76e commit 8dd5bba

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+687
-63
lines changed

lib/model/nns/graph.js

+146-59
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,106 @@ import ONNXImporter from './onnx/onnx_importer.js'
1111
*/
1212
/**
1313
* @typedef {PlainLayerObject & {input?: string | string[], name?: string}} LayerObject
14-
* @typedef {object} Node
15-
* @property {Layer} layer Layer
16-
* @property {string} name Name of the node
17-
* @property {string[]} [input] Input node names
18-
* @property {{index: number, subscript?: number}[]} parents Parent node informations
19-
* @property {Matrix | Matrix[]} [outputValue] Output value of this node to next layer
20-
* @property {Matrix[]} [gradientValue] Gradient value of this node from next layer
2114
*/
15+
16+
class Node {
17+
/**
18+
* @param {string} name Name of this node
19+
* @param {PlainLayerObject} layer Layer
20+
* @param {string[]} input Input node names
21+
* @param {{index: number, subscript: number | null}[]} parent Parent node informations
22+
* @param {ComputationalGraph} graph graph
23+
*/
24+
constructor(name, layer, input, parent, graph) {
25+
this.name = name
26+
this.layer = layer
27+
this.input = input
28+
29+
this._graph = graph
30+
this._parent = parent
31+
}
32+
33+
get parents() {
34+
if (this._parent) {
35+
return this._parent
36+
}
37+
const numberSubscriptRegexp = /\[([0-9]+)\]$/
38+
const stringSubscriptRegexp = /\[([a-zA-Z_][0-9a-zA-Z_]*)\]$/
39+
return (this._parent = this.input.map(p => {
40+
const nm = p && p.match(numberSubscriptRegexp)
41+
const sm = p && p.match(stringSubscriptRegexp)
42+
const subscript = nm ? +nm[1] : sm ? sm[1] : null
43+
const pname = nm || sm ? p.slice(0, -(nm || sm)[0].length) : p
44+
for (let k = 0; k < this._graph._nodes.length; k++) {
45+
if (this._graph._nodes[k].name === pname) {
46+
return { index: k, subscript }
47+
}
48+
}
49+
throw new NeuralnetworkException(`Unknown input name '${p}'.`)
50+
}))
51+
}
52+
53+
/**
54+
* Output value of this node to next layer
55+
* @type {Matrix | Matrix[]}
56+
*/
57+
get outputValue() {
58+
return this._outputValue
59+
}
60+
61+
/**
62+
* @param {Matrix | Matrix[]} value Output value of this node to next layer
63+
*/
64+
set outputValue(value) {
65+
this._outputValue = value
66+
}
67+
68+
/**
69+
* Gradient value of this node from next layer
70+
* @type {Matrix[]}
71+
*/
72+
get gradientValue() {
73+
return this._gradientValue
74+
}
75+
76+
/**
77+
* @param {Matrix[]} value Gradient value of this node from next layer
78+
*/
79+
set gradientValue(value) {
80+
this._gradientValue = value
81+
}
82+
83+
/**
84+
* Output value size
85+
* @type {number[]}
86+
*/
87+
get lastOutputSize() {
88+
return this.outputValue.sizes
89+
}
90+
91+
/**
92+
* Returns object representation.
93+
* @returns {LayerObject} Object represented this node
94+
*/
95+
toObject() {
96+
const obj = this.layer.toObject()
97+
if (this.name) {
98+
obj.name = this.name
99+
}
100+
if (this.input) {
101+
obj.input = this.input
102+
}
103+
return obj
104+
}
105+
}
106+
22107
/**
23108
* Computational graph for Neuralnetwork structure
24109
*/
25110
export default class ComputationalGraph {
26111
constructor() {
27112
this._nodes = []
113+
this._order = null
28114
}
29115

30116
/**
@@ -99,19 +185,7 @@ export default class ComputationalGraph {
99185
* @returns {LayerObject[]} Object represented this graph
100186
*/
101187
toObject() {
102-
const s = []
103-
for (let i = 0; i < this._nodes.length; i++) {
104-
const node = this._nodes[i]
105-
const obj = node.layer.toObject()
106-
if (node.name) {
107-
obj.name = node.name
108-
}
109-
if (node.input) {
110-
obj.input = node.input
111-
}
112-
s.push(obj)
113-
}
114-
return s
188+
return this._nodes.map(node => node.toObject())
115189
}
116190

117191
/**
@@ -146,51 +220,21 @@ export default class ComputationalGraph {
146220
* @param {string[] | string} [inputs] Input node names for the added layer
147221
*/
148222
add(layer, name, inputs = undefined) {
223+
this._order = null
149224
if (!(layer instanceof Layer)) {
150225
layer = Layer.fromObject(layer)
151226
}
152-
let parentinfos = []
227+
let parentinfos = null
153228
if (!inputs) {
229+
parentinfos = []
154230
if (layer.calc.length > 0 && this._nodes.length > 0) {
155-
parentinfos.push({
156-
index: this._nodes.length - 1,
157-
subscript: null,
158-
})
159-
}
160-
} else {
161-
if (typeof inputs === 'string') {
162-
inputs = [inputs]
231+
parentinfos.push({ index: this._nodes.length - 1, subscript: null })
163232
}
164-
const numberSubscriptRegexp = /\[([0-9]+)\]$/
165-
const stringSubscriptRegexp = /\[([a-zA-Z_][0-9a-zA-Z_]*)\]$/
166-
parentinfos = inputs.map(p => {
167-
const nm = p && p.match(numberSubscriptRegexp)
168-
const sm = p && p.match(stringSubscriptRegexp)
169-
const subscript = nm ? +nm[1] : sm ? sm[1] : null
170-
const pname = nm || sm ? p.slice(0, -(nm || sm)[0].length) : p
171-
for (let k = 0; k < this._nodes.length; k++) {
172-
if (this._nodes[k].name === pname) {
173-
return {
174-
index: k,
175-
subscript,
176-
}
177-
}
178-
}
179-
throw new NeuralnetworkException(`Unknown input name '${p}'.`)
180-
})
233+
} else if (typeof inputs === 'string') {
234+
inputs = [inputs]
181235
}
182236
layer.graph = this
183-
const node = {
184-
layer,
185-
name,
186-
input: inputs,
187-
parents: parentinfos,
188-
outputValue: null,
189-
gradientValue: null,
190-
get lastOutputSize() {
191-
return this.outputValue.sizes
192-
},
193-
}
237+
const node = new Node(name, layer, inputs, parentinfos, this)
194238
if (name && this.getNode(name)) {
195239
throw new NeuralnetworkException(`Duplicate layer name ${name}`)
196240
}
@@ -207,18 +251,60 @@ export default class ComputationalGraph {
207251
}
208252
}
209253

254+
_calcOrder() {
255+
if (this._order) {
256+
return
257+
}
258+
const s = []
259+
const outputList = Array.from(this._nodes, () => [])
260+
const addedParentCount = Array.from(this._nodes, n => n.parents.length)
261+
for (let i = 0; i < this._nodes.length; i++) {
262+
const node = this._nodes[i]
263+
for (const parent of node.parents) {
264+
outputList[parent.index].push(i)
265+
}
266+
for (const name of node.layer.dependentLayers) {
267+
for (let k = 0; k < this._nodes.length; k++) {
268+
if (i !== k && this._nodes[k].name === name && !outputList[k].includes(i)) {
269+
outputList[k].push(i)
270+
addedParentCount[i]++
271+
break
272+
}
273+
}
274+
}
275+
if (addedParentCount[i] === 0) {
276+
s.push(i)
277+
}
278+
}
279+
this._order = []
280+
while (s.length > 0) {
281+
const n = s.pop()
282+
this._order.push(n)
283+
for (const i of outputList[n]) {
284+
addedParentCount[i]--
285+
if (addedParentCount[i] === 0) {
286+
s.push(i)
287+
}
288+
}
289+
}
290+
if (addedParentCount.some(v => v !== 0)) {
291+
throw new Error('This graph is not directed acyclic graph.')
292+
}
293+
}
294+
210295
/**
211296
* Returns calculated values.
212297
* @param {(string | number)[]} [require] Name or index of nodes at least calculated
213298
*/
214299
calc(require) {
300+
this._calcOrder()
215301
for (let i = 0; i < this._nodes.length; i++) {
216302
this._nodes[i].outputValue = null
217303
this._nodes[i].gradientValue = null
218304
}
219305
const o = []
220306
const r = require ? Array(require.length).fill(false) : []
221-
for (let i = 0; i < this._nodes.length; i++) {
307+
for (const i of this._order) {
222308
try {
223309
const l = this._nodes[i]
224310
o[i] = l.layer.calc(
@@ -255,13 +341,14 @@ export default class ComputationalGraph {
255341
* @returns {Matrix} Output of gradient
256342
*/
257343
grad(e) {
344+
this._calcOrder()
258345
const n = this._nodes.length
259346
const bi = Array.from(this._nodes, () => [])
260347
const initGrad = this._nodes[n - 1].outputValue?.copy() ?? new Matrix(1, 1)
261348
initGrad.fill(1)
262349
bi[n - 1] = [initGrad]
263350
let bi_input = null
264-
for (let i = n - 1; i >= 0; i--) {
351+
for (const i of [...this._order].reverse()) {
265352
const l = this._nodes[i]
266353
if (e) {
267354
if (l.layer instanceof OutputLayer) {

lib/model/nns/layer/attention.js

+14
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,20 @@ export default class AttentionLayer extends Layer {
3838
}
3939
}
4040

41+
get dependentLayers() {
42+
const layers = []
43+
if (this._wqname) {
44+
layers.push(this._wqname)
45+
}
46+
if (this._wkname) {
47+
layers.push(this._wkname)
48+
}
49+
if (this._wvname) {
50+
layers.push(this._wvname)
51+
}
52+
return layers
53+
}
54+
4155
calc(x, memory) {
4256
this._selfattention = !memory
4357
if (!memory) {

lib/model/nns/layer/base.js

+8
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,14 @@ export default class Layer {
6060
layerClasses[name] = cls
6161
}
6262

63+
/**
64+
* List of names of other layers dependent on this layer.
65+
* @type {string[]}
66+
*/
67+
get dependentLayers() {
68+
return []
69+
}
70+
6371
/**
6472
* Bind pre-condition values.
6573
* @param {object} values Binding object

lib/model/nns/layer/batch_normalization.js

+17
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,23 @@ export default class BatchNormalizationLayer extends Layer {
3737
this._input_var = input_var
3838
}
3939

40+
get dependentLayers() {
41+
const layers = []
42+
if (this._scalename) {
43+
layers.push(this._scalename)
44+
}
45+
if (this._offsetname) {
46+
layers.push(this._offsetname)
47+
}
48+
if (typeof this._input_mean === 'string') {
49+
layers.push(this._input_mean)
50+
}
51+
if (typeof this._input_var === 'string') {
52+
layers.push(this._input_var)
53+
}
54+
return layers
55+
}
56+
4057
get mean() {
4158
return this._mean
4259
}

lib/model/nns/layer/clip.js

+11
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,17 @@ export default class ClipLayer extends Layer {
1515
this._max = max
1616
}
1717

18+
get dependentLayers() {
19+
const layers = []
20+
if (typeof this._min === 'string') {
21+
layers.push(this._min)
22+
}
23+
if (typeof this._max === 'string') {
24+
layers.push(this._max)
25+
}
26+
return layers
27+
}
28+
1829
calc(x) {
1930
const min = typeof this._min === 'string' ? this.graph.getNode(this._min).outputValue.toScaler() : this._min
2031
const max = typeof this._max === 'string' ? this.graph.getNode(this._max).outputValue.toScaler() : this._max

lib/model/nns/layer/conv.js

+11
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,17 @@ export default class ConvLayer extends Layer {
5959
this._l1_decay = l1_decay
6060
}
6161

62+
get dependentLayers() {
63+
const layers = []
64+
if (this._wname) {
65+
layers.push(this._wname)
66+
}
67+
if (this._activation) {
68+
layers.push(...this._activation.dependentLayers)
69+
}
70+
return layers
71+
}
72+
6273
_index(i, c, k) {
6374
return this._channel_dim === -1 ? [i, ...k, c] : [i, c, ...k]
6475
}

lib/model/nns/layer/full.js

+14
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,20 @@ export default class FullyConnected extends Layer {
3838
this._l1_decay = l1_decay
3939
}
4040

41+
get dependentLayers() {
42+
const layers = []
43+
if (this._wname) {
44+
layers.push(this._wname)
45+
}
46+
if (this._bname) {
47+
layers.push(this._bname)
48+
}
49+
if (this._activation) {
50+
layers.push(...this._activation.dependentLayers)
51+
}
52+
return layers
53+
}
54+
4155
calc(x) {
4256
if (this._wname) {
4357
this._w = this.graph.getNode(this._wname).outputValue

0 commit comments

Comments
 (0)