
Commit 56d7368

Add LRN layer support for ONNX export (#936)
1 parent b4d4617 commit 56d7368

3 files changed: +164 -0 lines changed
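
For context, the ONNX LRN (Local Response Normalization) operator divides each value by a power of a sum of squares taken over neighboring channels: y = x / (bias + alpha / size * square_sum)^beta, where square_sum at channel c covers channels from c - floor((size - 1) / 2) to c + ceil((size - 1) / 2), clipped to the valid range. A minimal sketch of that per-position computation along the channel axis, assuming the standard ONNX attribute semantics (illustration only, not code from this commit):

// Illustration only: ONNX LRN formula applied along the channel axis for one
// spatial position. 'x' holds the channel values; defaults follow the ONNX spec.
const lrn1d = (x, { size, alpha = 0.0001, beta = 0.75, bias = 1.0 }) => {
	return x.map((v, c) => {
		const lo = Math.max(0, c - Math.floor((size - 1) / 2))
		const hi = Math.min(x.length - 1, c + Math.ceil((size - 1) / 2))
		let squareSum = 0
		for (let i = lo; i <= hi; i++) {
			squareSum += x[i] ** 2
		}
		return v / (bias + (alpha / size) * squareSum) ** beta
	})
}
console.log(lrn1d([1, 2, 3, 4], { size: 3 }))

The exporter added below maps the layer attributes n, alpha, beta and k onto the node attributes size, alpha, beta and bias, respectively.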

lib/model/nns/onnx/layer/index.js

+1
@@ -77,6 +77,7 @@ export { default as log_softmax } from './log_softmax.js'
 export { default as loglog } from './loglog.js'
 export { default as logsigmoid } from './logsigmoid.js'
 export { default as lp_pool } from './lp_pool.js'
+export { default as lrn } from './lrn.js'
 export { default as matmul } from './matmul.js'
 export { default as max } from './max.js'
 export { default as max_pool } from './max_pool.js'
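
Re-exporting the new module from the layer index makes the handler resolvable by its layer type name. A hypothetical lookup, with the import style and variable names assumed for illustration rather than taken from this commit:

// Hypothetical illustration: once 'lrn' is re-exported from the index,
// the handler for a node of type 'lrn' can be looked up by name.
import * as layers from './lib/model/nns/onnx/layer/index.js'

const handler = layers['lrn']
// handler.export(model, { type: 'lrn', name: 'y', input: 'x', n: 3 }, info)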

lib/model/nns/onnx/layer/lrn.js

+86
@@ -0,0 +1,86 @@
import { onnx } from '../onnx_exporter.js'

/**
 * Handle lrn layer
 */
export default {
	/**
	 * Export to onnx object.
	 * @param {onnx.ModelProto} model Model object
	 * @param {import("../../graph.js").LayerObject & {type: 'lrn'}} obj Node object
	 * @param {{[key: string]: {type: onnx.TensorProto.DataType; size: number[]}}} info Output information of other layers
	 */
	export(model, obj, info) {
		const graph = model.getGraph()

		const input = Array.isArray(obj.input) ? obj.input[0] : obj.input
		const size = info[input].size.concat()

		const node = new onnx.NodeProto()
		node.setOpType('LRN')
		if (obj.channel_dim === 1) {
			node.addInput(input)
			node.addOutput(obj.name)
		} else if (obj.channel_dim == null || obj.channel_dim === -1) {
			const node_transpose1 = new onnx.NodeProto()
			node_transpose1.setOpType('Transpose')
			node_transpose1.addInput(input)
			node_transpose1.addOutput(obj.name + '_t1')
			const attrPerm1 = new onnx.AttributeProto()
			attrPerm1.setName('perm')
			attrPerm1.setType(onnx.AttributeProto.AttributeType.INTS)
			const perm1 = Array.from(size, (_, i) => i - 1)
			perm1[0] = 0
			perm1[1] = size.length - 1
			attrPerm1.setIntsList(perm1)
			node_transpose1.addAttribute(attrPerm1)
			graph.addNode(node_transpose1)

			node.addInput(obj.name + '_t1')
			node.addOutput(obj.name + '_gap')

			const node_transpose2 = new onnx.NodeProto()
			node_transpose2.setOpType('Transpose')
			node_transpose2.addInput(obj.name + '_gap')
			node_transpose2.addOutput(obj.name)
			const attrPerm2 = new onnx.AttributeProto()
			attrPerm2.setName('perm')
			attrPerm2.setType(onnx.AttributeProto.AttributeType.INTS)
			const perm2 = Array.from(size, (_, i) => i + 1)
			perm2[0] = 0
			perm2[perm2.length - 1] = 1
			attrPerm2.setIntsList(perm2)
			node_transpose2.addAttribute(attrPerm2)
			graph.addNode(node_transpose2)
		} else {
			throw new Error(`Not implemented value of attribute 'channel_dim' ${obj.channel_dim}.`)
		}

		if (obj.n == null) {
			throw new Error("Require attribute 'n'")
		}
		const attrSize = new onnx.AttributeProto()
		attrSize.setName('size')
		attrSize.setType(onnx.AttributeProto.AttributeType.INT)
		attrSize.setI(obj.n)
		node.addAttribute(attrSize)

		const attrAlpha = new onnx.AttributeProto()
		attrAlpha.setName('alpha')
		attrAlpha.setType(onnx.AttributeProto.AttributeType.FLOAT)
		attrAlpha.setF(obj.alpha ?? 0.0001)
		node.addAttribute(attrAlpha)
		const attrBeta = new onnx.AttributeProto()
		attrBeta.setName('beta')
		attrBeta.setType(onnx.AttributeProto.AttributeType.FLOAT)
		attrBeta.setF(obj.beta ?? 0.75)
		node.addAttribute(attrBeta)
		const attrBias = new onnx.AttributeProto()
		attrBias.setName('bias')
		attrBias.setType(onnx.AttributeProto.AttributeType.FLOAT)
		attrBias.setF(obj.k ?? 2)
		node.addAttribute(attrBias)

		graph.addNode(node)
	},
}
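
ONNX's LRN operator always normalizes over axis 1, so when the layer's channel dimension is the last axis (channel_dim of -1 or unset) the exporter brackets the LRN node with a pair of Transpose nodes. A small sketch reproducing the perm arrays built above, assuming a 4-D channel-last input (illustration only, not code from this commit):

// Illustration only: the same permutation logic as in lrn.js, moving the
// channel axis from the last position to axis 1 before LRN and back afterwards.
const channelToSecond = size => {
	const perm = Array.from(size, (_, i) => i - 1)
	perm[0] = 0
	perm[1] = size.length - 1
	return perm
}
const channelToLast = size => {
	const perm = Array.from(size, (_, i) => i + 1)
	perm[0] = 0
	perm[perm.length - 1] = 1
	return perm
}
console.log(channelToSecond([null, 5, 5, 3])) // [0, 3, 1, 2]: (N, H, W, C) -> (N, C, H, W)
console.log(channelToLast([null, 5, 5, 3])) // [0, 2, 3, 1]: (N, C, H, W) -> (N, H, W, C)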
+77
@@ -0,0 +1,77 @@
import { jest } from '@jest/globals'
jest.retryTimes(3)

import * as ort from 'onnxruntime-web'
ort.env.wasm.numThreads = 1

import ONNXExporter from '../../../../../../lib/model/nns/onnx/onnx_exporter.js'
import lrn from '../../../../../../lib/model/nns/onnx/layer/lrn.js'
import LRNLayer from '../../../../../../lib/model/nns/layer/lrn.js'
import Tensor from '../../../../../../lib/util/tensor.js'

describe('export', () => {
	test.each([
		{ input: 'x', channel_dim: -1, n: 3 },
		{ input: ['x'], n: 3 },
	])('last channel %p', param => {
		const model = ONNXExporter.createONNXModel()
		lrn.export(model, { type: 'lrn', ...param }, { x: { size: [null, 10, 3] } })
		const nodes = model.getGraph().getNodeList()
		expect(nodes).toHaveLength(3)
		expect(nodes[0].getOpType()).toBe('Transpose')
		expect(nodes[1].getOpType()).toBe('Transpose')
		expect(nodes[2].getOpType()).toBe('LRN')
	})

	test('first channel', () => {
		const model = ONNXExporter.createONNXModel()
		lrn.export(model, { type: 'lrn', input: 'x', channel_dim: 1, n: 3 }, { x: { size: [null, 10, 3] } })
		const nodes = model.getGraph().getNodeList()
		expect(nodes).toHaveLength(1)
		expect(nodes[0].getOpType()).toBe('LRN')
	})

	test('invalid channel dim', () => {
		const model = ONNXExporter.createONNXModel()
		expect(() =>
			lrn.export(model, { type: 'lrn', input: ['x'], channel_dim: 0, n: 3 }, { x: { size: [null, 10, 3] } })
		).toThrow("Not implemented value of attribute 'channel_dim' 0")
	})

	test('require n', () => {
		const model = ONNXExporter.createONNXModel()
		expect(() =>
			lrn.export(model, { type: 'lrn', input: ['x'], channel_dim: -1 }, { x: { size: [null, 10, 3] } })
		).toThrow("Require attribute 'n'")
	})
})

describe('runtime', () => {
	let session
	afterEach(async () => {
		await session?.release()
		session = null
	})

	test.each([
		[{ channel_dim: 1, n: 3 }, [null, 4, 3, 3], [1, 4, 3, 3]],
		[{ n: 5 }, [null, 4, 4, 10], [1, 4, 4, 10]],
		[{ alpha: 0.0002, beta: 0.7, k: 2, n: 5 }, [null, 3, 3, 5], [1, 3, 3, 5]],
	])('lrn %p %p %p', async (param, inSize, actualSize) => {
		const buf = ONNXExporter.dump([{ type: 'input', size: inSize }, { type: 'lrn', ...param }, { type: 'output' }])
		session = await ort.InferenceSession.create(buf)

		const x = Tensor.randn(actualSize)
		const xten = new ort.Tensor('float32', x.value, x.sizes)
		const out = await session.run({ _input: xten })
		const yten = out._lrn
		expect(yten.dims).toEqual(actualSize)
		const y = await yten.getData(true)

		const t = new LRNLayer(param).calc(x)
		expect(yten.dims).toEqual(t.sizes)
		for (let i = 0; i < y.length; i++) {
			expect(y[i]).toBeCloseTo(t.value[i])
		}
	})
})
