Skip to content

Commit 3b9c16b

Browse files
authored
Fix comparison layers and update related tests (#938)
1 parent 56d7368 commit 3b9c16b

17 files changed

+629
-130
lines changed

create_import_list.js

+3-1
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,9 @@ const createLayerlist = async () => {
134134
source.forEachChild(node => {
135135
if (
136136
ts.isVariableStatement(node) &&
137-
['unaryLayers', 'binaryLayers'].includes(node.declarationList.declarations[0].name.escapedText)
137+
['unaryLayers', 'binaryLayers', 'compareLayers'].includes(
138+
node.declarationList.declarations[0].name.escapedText
139+
)
138140
) {
139141
const init = node.declarationList.declarations[0].initializer
140142
for (const property of init.properties) {

lib/model/nns/layer/base.js

+41-20
Original file line numberDiff line numberDiff line change
@@ -469,30 +469,10 @@ const binaryLayers = {
469469
return v
470470
},
471471
},
472-
equal: {
473-
calc: (x1, x2) => x1 === x2,
474-
gradFunc: () => 0,
475-
},
476-
greater: {
477-
calc: (x1, x2) => x1 > x2,
478-
gradFunc: () => 0,
479-
},
480-
greater_or_equal: {
481-
calc: (x1, x2) => x1 >= x2,
482-
gradFunc: () => 0,
483-
},
484472
left_bitshift: {
485473
calc: (x1, x2) => x1 << x2,
486474
gradFunc: () => 0,
487475
},
488-
less: {
489-
calc: (x1, x2) => x1 < x2,
490-
gradFunc: () => 0,
491-
},
492-
less_or_equal: {
493-
calc: (x1, x2) => x1 <= x2,
494-
gradFunc: () => 0,
495-
},
496476
max: {
497477
calc: Math.max,
498478
gradFunc: (k, x) => {
@@ -563,3 +543,44 @@ const binaryLayers = {
563543
for (const name of Object.keys(binaryLayers)) {
564544
buildBinaryLayer(name, binaryLayers[name].calc, binaryLayers[name].gradFunc)
565545
}
546+
547+
const buildCompareLayer = (name, calcFunc) => {
548+
class TempLayer extends Layer {
549+
calc(...x) {
550+
this._i = x
551+
this._o = x[0].copy()
552+
this._o.map(() => true)
553+
for (let i = 1; i < x.length; i++) {
554+
const xi = x[i - 1].copy()
555+
xi.broadcastOperate(x[i], calcFunc)
556+
this._o.broadcastOperate(xi, (a, b) => a && b)
557+
}
558+
return this._o
559+
}
560+
561+
grad() {
562+
const bi = this._i.map(x => {
563+
const bi = x.copy()
564+
bi.fill(0)
565+
return bi
566+
})
567+
return bi
568+
}
569+
}
570+
Object.defineProperty(TempLayer, 'name', {
571+
value: name.split('_').reduce((s, nm) => s + nm[0].toUpperCase() + nm.substring(1).toLowerCase(), '') + 'Layer',
572+
})
573+
TempLayer.registLayer(name)
574+
}
575+
576+
const compareLayers = {
577+
equal: { calc: (x1, x2) => x1 === x2 },
578+
greater: { calc: (x1, x2) => x1 > x2 },
579+
greater_or_equal: { calc: (x1, x2) => x1 >= x2 },
580+
less: { calc: (x1, x2) => x1 < x2 },
581+
less_or_equal: { calc: (x1, x2) => x1 <= x2 },
582+
}
583+
584+
for (const name of Object.keys(compareLayers)) {
585+
buildCompareLayer(name, compareLayers[name].calc)
586+
}

lib/model/nns/onnx/layer/equal.js

+45-3
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,51 @@ export default {
1616
}
1717
const graph = model.getGraph()
1818
const node = new onnx.NodeProto()
19-
node.setOpType('Equal')
20-
for (const i of obj.input) {
21-
node.addInput(i)
19+
if (obj.input.length === 1) {
20+
const node_shape = new onnx.NodeProto()
21+
node_shape.setOpType('Shape')
22+
node_shape.addInput(obj.input[0])
23+
node_shape.addOutput(obj.name + '_shape')
24+
graph.addNode(node_shape)
25+
26+
node.setOpType('ConstantOfShape')
27+
node.addInput(obj.name + '_shape')
28+
29+
const tensor = new onnx.TensorProto()
30+
tensor.setDataType(onnx.TensorProto.DataType.BOOL)
31+
tensor.setDimsList([1])
32+
tensor.setInt32DataList([1])
33+
const attrValue = new onnx.AttributeProto()
34+
attrValue.setName('value')
35+
attrValue.setType(onnx.AttributeProto.AttributeType.TENSOR)
36+
attrValue.setT(tensor)
37+
node.addAttribute(attrValue)
38+
} else if (obj.input.length === 2) {
39+
node.setOpType('Equal')
40+
node.addInput(obj.input[0])
41+
node.addInput(obj.input[1])
42+
} else {
43+
for (let i = 0; i < obj.input.length - 1; i++) {
44+
const node_equal = new onnx.NodeProto()
45+
node_equal.setOpType('Equal')
46+
node_equal.addInput(obj.input[i])
47+
node_equal.addInput(obj.input[i + 1])
48+
node_equal.addOutput(`${obj.name}_eq_${i}`)
49+
graph.addNode(node_equal)
50+
}
51+
let prev_in = obj.name + '_eq_0'
52+
for (let i = 1; i < obj.input.length - 2; i++) {
53+
const node_mul = new onnx.NodeProto()
54+
node_mul.setOpType('And')
55+
node_mul.addInput(prev_in)
56+
node_mul.addInput(`${obj.name}_eq_${i}`)
57+
node_mul.addOutput((prev_in = obj.name + `_and_${i - 1}`))
58+
graph.addNode(node_mul)
59+
}
60+
61+
node.setOpType('And')
62+
node.addInput(prev_in)
63+
node.addInput(`${obj.name}_eq_${obj.input.length - 2}`)
2264
}
2365
node.addOutput(obj.name)
2466
graph.addNode(node)

lib/model/nns/onnx/layer/greater.js

+45-4
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,52 @@ export default {
1515
throw new Error(`Invalid attribute 'input' value ${obj.input}.`)
1616
}
1717
const graph = model.getGraph()
18-
1918
const node = new onnx.NodeProto()
20-
node.setOpType('Greater')
21-
for (const i of obj.input) {
22-
node.addInput(i)
19+
if (obj.input.length === 1) {
20+
const node_shape = new onnx.NodeProto()
21+
node_shape.setOpType('Shape')
22+
node_shape.addInput(obj.input[0])
23+
node_shape.addOutput(obj.name + '_shape')
24+
graph.addNode(node_shape)
25+
26+
node.setOpType('ConstantOfShape')
27+
node.addInput(obj.name + '_shape')
28+
29+
const tensor = new onnx.TensorProto()
30+
tensor.setDataType(onnx.TensorProto.DataType.BOOL)
31+
tensor.setDimsList([1])
32+
tensor.setInt32DataList([1])
33+
const attrValue = new onnx.AttributeProto()
34+
attrValue.setName('value')
35+
attrValue.setType(onnx.AttributeProto.AttributeType.TENSOR)
36+
attrValue.setT(tensor)
37+
node.addAttribute(attrValue)
38+
} else if (obj.input.length === 2) {
39+
node.setOpType('Greater')
40+
node.addInput(obj.input[0])
41+
node.addInput(obj.input[1])
42+
} else {
43+
for (let i = 0; i < obj.input.length - 1; i++) {
44+
const node_equal = new onnx.NodeProto()
45+
node_equal.setOpType('Greater')
46+
node_equal.addInput(obj.input[i])
47+
node_equal.addInput(obj.input[i + 1])
48+
node_equal.addOutput(`${obj.name}_gt_${i}`)
49+
graph.addNode(node_equal)
50+
}
51+
let prev_in = obj.name + '_gt_0'
52+
for (let i = 1; i < obj.input.length - 2; i++) {
53+
const node_mul = new onnx.NodeProto()
54+
node_mul.setOpType('And')
55+
node_mul.addInput(prev_in)
56+
node_mul.addInput(`${obj.name}_gt_${i}`)
57+
node_mul.addOutput((prev_in = obj.name + `_and_${i - 1}`))
58+
graph.addNode(node_mul)
59+
}
60+
61+
node.setOpType('And')
62+
node.addInput(prev_in)
63+
node.addInput(`${obj.name}_gt_${obj.input.length - 2}`)
2364
}
2465
node.addOutput(obj.name)
2566
graph.addNode(node)

lib/model/nns/onnx/layer/greater_or_equal.js

+45-4
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,52 @@ export default {
1515
throw new Error(`Invalid attribute 'input' value ${obj.input}.`)
1616
}
1717
const graph = model.getGraph()
18-
1918
const node = new onnx.NodeProto()
20-
node.setOpType('GreaterOrEqual')
21-
for (const i of obj.input) {
22-
node.addInput(i)
19+
if (obj.input.length === 1) {
20+
const node_shape = new onnx.NodeProto()
21+
node_shape.setOpType('Shape')
22+
node_shape.addInput(obj.input[0])
23+
node_shape.addOutput(obj.name + '_shape')
24+
graph.addNode(node_shape)
25+
26+
node.setOpType('ConstantOfShape')
27+
node.addInput(obj.name + '_shape')
28+
29+
const tensor = new onnx.TensorProto()
30+
tensor.setDataType(onnx.TensorProto.DataType.BOOL)
31+
tensor.setDimsList([1])
32+
tensor.setInt32DataList([1])
33+
const attrValue = new onnx.AttributeProto()
34+
attrValue.setName('value')
35+
attrValue.setType(onnx.AttributeProto.AttributeType.TENSOR)
36+
attrValue.setT(tensor)
37+
node.addAttribute(attrValue)
38+
} else if (obj.input.length === 2) {
39+
node.setOpType('GreaterOrEqual')
40+
node.addInput(obj.input[0])
41+
node.addInput(obj.input[1])
42+
} else {
43+
for (let i = 0; i < obj.input.length - 1; i++) {
44+
const node_equal = new onnx.NodeProto()
45+
node_equal.setOpType('GreaterOrEqual')
46+
node_equal.addInput(obj.input[i])
47+
node_equal.addInput(obj.input[i + 1])
48+
node_equal.addOutput(`${obj.name}_ge_${i}`)
49+
graph.addNode(node_equal)
50+
}
51+
let prev_in = obj.name + '_ge_0'
52+
for (let i = 1; i < obj.input.length - 2; i++) {
53+
const node_mul = new onnx.NodeProto()
54+
node_mul.setOpType('And')
55+
node_mul.addInput(prev_in)
56+
node_mul.addInput(`${obj.name}_ge_${i}`)
57+
node_mul.addOutput((prev_in = obj.name + `_and_${i - 1}`))
58+
graph.addNode(node_mul)
59+
}
60+
61+
node.setOpType('And')
62+
node.addInput(prev_in)
63+
node.addInput(`${obj.name}_ge_${obj.input.length - 2}`)
2364
}
2465
node.addOutput(obj.name)
2566
graph.addNode(node)

lib/model/nns/onnx/layer/less.js

+45-4
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,52 @@ export default {
1515
throw new Error(`Invalid attribute 'input' value ${obj.input}.`)
1616
}
1717
const graph = model.getGraph()
18-
1918
const node = new onnx.NodeProto()
20-
node.setOpType('Less')
21-
for (const i of obj.input) {
22-
node.addInput(i)
19+
if (obj.input.length === 1) {
20+
const node_shape = new onnx.NodeProto()
21+
node_shape.setOpType('Shape')
22+
node_shape.addInput(obj.input[0])
23+
node_shape.addOutput(obj.name + '_shape')
24+
graph.addNode(node_shape)
25+
26+
node.setOpType('ConstantOfShape')
27+
node.addInput(obj.name + '_shape')
28+
29+
const tensor = new onnx.TensorProto()
30+
tensor.setDataType(onnx.TensorProto.DataType.BOOL)
31+
tensor.setDimsList([1])
32+
tensor.setInt32DataList([1])
33+
const attrValue = new onnx.AttributeProto()
34+
attrValue.setName('value')
35+
attrValue.setType(onnx.AttributeProto.AttributeType.TENSOR)
36+
attrValue.setT(tensor)
37+
node.addAttribute(attrValue)
38+
} else if (obj.input.length === 2) {
39+
node.setOpType('Less')
40+
node.addInput(obj.input[0])
41+
node.addInput(obj.input[1])
42+
} else {
43+
for (let i = 0; i < obj.input.length - 1; i++) {
44+
const node_equal = new onnx.NodeProto()
45+
node_equal.setOpType('Less')
46+
node_equal.addInput(obj.input[i])
47+
node_equal.addInput(obj.input[i + 1])
48+
node_equal.addOutput(`${obj.name}_lt_${i}`)
49+
graph.addNode(node_equal)
50+
}
51+
let prev_in = obj.name + '_lt_0'
52+
for (let i = 1; i < obj.input.length - 2; i++) {
53+
const node_mul = new onnx.NodeProto()
54+
node_mul.setOpType('And')
55+
node_mul.addInput(prev_in)
56+
node_mul.addInput(`${obj.name}_lt_${i}`)
57+
node_mul.addOutput((prev_in = obj.name + `_and_${i - 1}`))
58+
graph.addNode(node_mul)
59+
}
60+
61+
node.setOpType('And')
62+
node.addInput(prev_in)
63+
node.addInput(`${obj.name}_lt_${obj.input.length - 2}`)
2364
}
2465
node.addOutput(obj.name)
2566
graph.addNode(node)

lib/model/nns/onnx/layer/less_or_equal.js

+45-4
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,52 @@ export default {
1515
throw new Error(`Invalid attribute 'input' value ${obj.input}.`)
1616
}
1717
const graph = model.getGraph()
18-
1918
const node = new onnx.NodeProto()
20-
node.setOpType('LessOrEqual')
21-
for (const i of obj.input) {
22-
node.addInput(i)
19+
if (obj.input.length === 1) {
20+
const node_shape = new onnx.NodeProto()
21+
node_shape.setOpType('Shape')
22+
node_shape.addInput(obj.input[0])
23+
node_shape.addOutput(obj.name + '_shape')
24+
graph.addNode(node_shape)
25+
26+
node.setOpType('ConstantOfShape')
27+
node.addInput(obj.name + '_shape')
28+
29+
const tensor = new onnx.TensorProto()
30+
tensor.setDataType(onnx.TensorProto.DataType.BOOL)
31+
tensor.setDimsList([1])
32+
tensor.setInt32DataList([1])
33+
const attrValue = new onnx.AttributeProto()
34+
attrValue.setName('value')
35+
attrValue.setType(onnx.AttributeProto.AttributeType.TENSOR)
36+
attrValue.setT(tensor)
37+
node.addAttribute(attrValue)
38+
} else if (obj.input.length === 2) {
39+
node.setOpType('LessOrEqual')
40+
node.addInput(obj.input[0])
41+
node.addInput(obj.input[1])
42+
} else {
43+
for (let i = 0; i < obj.input.length - 1; i++) {
44+
const node_equal = new onnx.NodeProto()
45+
node_equal.setOpType('LessOrEqual')
46+
node_equal.addInput(obj.input[i])
47+
node_equal.addInput(obj.input[i + 1])
48+
node_equal.addOutput(`${obj.name}_le_${i}`)
49+
graph.addNode(node_equal)
50+
}
51+
let prev_in = obj.name + '_le_0'
52+
for (let i = 1; i < obj.input.length - 2; i++) {
53+
const node_mul = new onnx.NodeProto()
54+
node_mul.setOpType('And')
55+
node_mul.addInput(prev_in)
56+
node_mul.addInput(`${obj.name}_le_${i}`)
57+
node_mul.addOutput((prev_in = obj.name + `_and_${i - 1}`))
58+
graph.addNode(node_mul)
59+
}
60+
61+
node.setOpType('And')
62+
node.addInput(prev_in)
63+
node.addInput(`${obj.name}_le_${obj.input.length - 2}`)
2364
}
2465
node.addOutput(obj.name)
2566
graph.addNode(node)

tests/lib/model/nns/layer/equal.test.js

+13-7
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,22 @@ describe('layer', () => {
1111
})
1212

1313
describe('calc', () => {
14-
test('matrix', () => {
14+
test.each([2, 3])('matrix %p', n => {
1515
const layer = Layer.fromObject({ type: 'equal' })
1616

17-
const a = Matrix.randint(100, 10, -5, 5)
18-
const b = Matrix.randint(100, 10, -5, 5)
17+
const x = []
18+
for (let i = 0; i < n; i++) {
19+
x.push(Matrix.randint(100, 10, -2, 2))
20+
}
1921

20-
const y = layer.calc(a, b)
21-
for (let i = 0; i < a.rows; i++) {
22-
for (let j = 0; j < a.cols; j++) {
23-
expect(y.at(i, j)).toBe(a.at(i, j) === b.at(i, j))
22+
const y = layer.calc(...x)
23+
for (let i = 0; i < x[0].rows; i++) {
24+
for (let j = 0; j < x[0].cols; j++) {
25+
let f = true
26+
for (let k = 0; k < n - 1; k++) {
27+
f &&= x[k].at(i, j) === x[k + 1].at(i, j)
28+
}
29+
expect(y.at(i, j)).toBe(f)
2430
}
2531
}
2632
})

0 commit comments

Comments
 (0)