Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 754c032

Browse files
authoredMay 15, 2024··
Move model termination process in ai_manager (#848)
1 parent 120f346 commit 754c032

34 files changed

+374
-475
lines changed
 

‎js/controller.js

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ export default class Controller {
44
this._e = root ?? platform.setting.ml.configElement.node()
55
this._terminators = []
66

7-
platform.setting.terminate = this.terminate.bind(this)
7+
platform._manager._terminateFunction.push(this.terminate.bind(this))
88
this.input = this.input.bind(this)
99

1010
this.input.text = conf => {

‎js/manager.js

+7-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ export default class AIManager {
2323
this._preprocess = []
2424
this._preprocessnames = []
2525
this._modelname = ''
26+
this._terminateFunction = []
2627

2728
this._emitter = new EventEmitter()
2829
}
@@ -149,6 +150,8 @@ export default class AIManager {
149150
}
150151

151152
async setModel(model) {
153+
this._terminateFunction.forEach(t => t())
154+
this._terminateFunction = []
152155
this._modelname = model
153156

154157
if (!model) {
@@ -158,7 +161,10 @@ export default class AIManager {
158161
loadedModel[model] = obj.default
159162
}
160163
try {
161-
loadedModel[model](this.platform)
164+
const tf = loadedModel[model](this.platform)
165+
if (tf) {
166+
this._terminateFunction.push(tf)
167+
}
162168
} catch (e) {
163169
console.error(e)
164170
return e

‎js/model_selector.js

-7
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,6 @@ app.component('model-selector', {
604604
aiTask: AITask,
605605
aiPreprocess: AIPreprocess,
606606
modelFilter: '',
607-
terminateFunction: [],
608607
state: {},
609608
mlData: 'manual',
610609
mlTask: '',
@@ -613,9 +612,6 @@ app.component('model-selector', {
613612
isLoadParam: false,
614613
historyWillPush: false,
615614
settings: (_this => ({
616-
set terminate(value) {
617-
_this.terminateFunction.push(value)
618-
},
619615
rl: {
620616
get configElement() {
621617
return document.querySelector('#rl_menu')
@@ -1033,9 +1029,6 @@ app.component('model-selector', {
10331029
return title
10341030
},
10351031
ready() {
1036-
this.terminateFunction.forEach(t => t())
1037-
this.terminateFunction = []
1038-
10391032
const mlModel = this.mlModel
10401033
const mlelem = document.querySelector('#method_menu')
10411034
mlelem.querySelector('.buttons').replaceChildren()

‎js/view/a2c.js

+1-1
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ export default function (platform) {
158158
elm.disabled = true
159159
}
160160

161-
platform.setting.terminate = () => {
161+
return () => {
162162
isRunning = false
163163
agent.terminate()
164164
}

‎js/view/agglomerative.js

+3-3
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,6 @@ const argmax = function (arr, key) {
3030
export default function (platform) {
3131
platform.setting.ml.usage =
3232
'Click and add data point. Next, select distance type and click "Initialize". Finally, select cluster number.'
33-
platform.setting.terminate = () => {
34-
document.querySelector('svg .grouping').remove()
35-
}
3633
const svg = platform.svg
3734
const line = p => {
3835
let s = ''
@@ -222,4 +219,7 @@ export default function (platform) {
222219
.on('input', () => {
223220
clusternumbeript.value = clusternumber.value
224221
})
222+
return () => {
223+
document.querySelector('svg .grouping').remove()
224+
}
225225
}

‎js/view/autoencoder.js

+63-133
Original file line numberDiff line numberDiff line change
@@ -25,102 +25,73 @@ class AutoencoderWorker extends BaseWorker {
2525
}
2626
}
2727

28-
var dispAEClt = function (elm, model, platform) {
29-
const step = 8
30-
31-
return async cb => {
32-
const iteration = +elm.select('[name=iteration]').property('value')
33-
const batch = +elm.select('[name=batch]').property('value')
34-
const rate = +elm.select('[name=rate]').property('value')
35-
const rho = +elm.select('[name=rho]').property('value')
36-
const fite = await model.fit(platform.trainInput, iteration, rate, batch, rho)
37-
platform.plotLoss(fite.loss)
38-
let p_mat = Matrix.fromArray(await model.reduce(platform.trainInput))
39-
40-
const t_mat = p_mat.argmax(1).value.map(v => v + 1)
41-
let tp_mat = Matrix.fromArray(await model.reduce(platform.testInput(step)))
42-
let categories = tp_mat.argmax(1)
43-
categories.add(1)
44-
platform.trainResult = t_mat
45-
platform.testResult(categories.value)
46-
47-
cb && cb(fite.epoch)
48-
}
49-
}
50-
51-
var dispAEad = function (elm, model, platform) {
52-
return async cb => {
53-
const iteration = +elm.select('[name=iteration]').property('value')
54-
const batch = +elm.select('[name=batch]').property('value')
55-
const rate = +elm.select('[name=rate]').property('value')
56-
const rho = +elm.select('[name=rho]').property('value')
57-
const threshold = +elm.select('[name=threshold]').property('value')
58-
59-
const tx = platform.trainInput
60-
const fite = await model.fit(tx, iteration, rate, batch, rho)
61-
platform.plotLoss(fite.loss)
62-
const px = platform.testInput(4)
63-
let pd = [].concat(tx, px)
64-
const e = await model.predict(pd)
65-
let pred = e.data.slice(0, tx.length)
66-
let pred_tile = e.data.slice(tx.length)
67-
let d = tx[0].length
68-
69-
const outliers = []
70-
for (let i = 0; i < pred.length; i++) {
71-
let v = 0
72-
for (let k = 0; k < d; k++) {
73-
v += (pred[i][k] - tx[i][k]) ** 2
74-
}
75-
outliers.push(v > threshold)
76-
}
77-
const outlier_tiles = []
78-
for (let i = 0; i < pred_tile.length; i++) {
79-
let v = 0
80-
for (let k = 0; k < d; k++) {
81-
v += (pred_tile[i][k] - px[i][k]) ** 2
82-
}
83-
outlier_tiles.push(v > threshold)
84-
}
85-
platform.trainResult = outliers
86-
platform.testResult(outlier_tiles)
87-
88-
cb && cb(fite.epoch)
89-
}
90-
}
91-
92-
var dispAEdr = function (elm, model, platform) {
93-
return async cb => {
94-
const iteration = +elm.select('[name=iteration]').property('value')
95-
const batch = +elm.select('[name=batch]').property('value')
96-
const rate = +elm.select('[name=rate]').property('value')
97-
const rho = +elm.select('[name=rho]').property('value')
98-
99-
const fite = await model.fit(platform.trainInput, iteration, rate, batch, rho)
100-
platform.plotLoss(fite.loss)
101-
platform.trainResult = await model.reduce(platform.trainInput)
102-
cb && cb(fite.epoch)
103-
}
104-
}
105-
106-
var dispAE = function (elm, platform) {
28+
export default function (platform) {
29+
platform.setting.ml.usage =
30+
'Click and add data point. Next, click "Initialize". Finally, click "Fit" button repeatedly.'
10731
const mode = platform.task
10832
const controller = new Controller(platform)
10933
const model = new AutoencoderWorker()
11034
let epoch = 0
111-
const fitModel =
112-
mode === 'AD'
113-
? dispAEad(elm, model, platform)
114-
: mode === 'CT'
115-
? dispAEClt(elm, model, platform)
116-
: dispAEdr(elm, model, platform)
35+
const fitModel = async cb => {
36+
if (mode === 'AD') {
37+
const tx = platform.trainInput
38+
const fite = await model.fit(tx, +iteration.value, rate.value, batch.value, rho.value)
39+
platform.plotLoss(fite.loss)
40+
const px = platform.testInput(4)
41+
let pd = [].concat(tx, px)
42+
const e = await model.predict(pd)
43+
let pred = e.data.slice(0, tx.length)
44+
let pred_tile = e.data.slice(tx.length)
45+
let d = tx[0].length
46+
47+
const outliers = []
48+
for (let i = 0; i < pred.length; i++) {
49+
let v = 0
50+
for (let k = 0; k < d; k++) {
51+
v += (pred[i][k] - tx[i][k]) ** 2
52+
}
53+
outliers.push(v > threshold.value)
54+
}
55+
const outlier_tiles = []
56+
for (let i = 0; i < pred_tile.length; i++) {
57+
let v = 0
58+
for (let k = 0; k < d; k++) {
59+
v += (pred_tile[i][k] - px[i][k]) ** 2
60+
}
61+
outlier_tiles.push(v > threshold.value)
62+
}
63+
platform.trainResult = outliers
64+
platform.testResult(outlier_tiles)
65+
66+
cb && cb(fite.epoch)
67+
} else if (mode === 'CT') {
68+
const step = 8
69+
const fite = await model.fit(platform.trainInput, +iteration.value, rate.value, batch.value, rho.value)
70+
platform.plotLoss(fite.loss)
71+
let p_mat = Matrix.fromArray(await model.reduce(platform.trainInput))
72+
73+
const t_mat = p_mat.argmax(1).value.map(v => v + 1)
74+
let tp_mat = Matrix.fromArray(await model.reduce(platform.testInput(step)))
75+
let categories = tp_mat.argmax(1)
76+
categories.add(1)
77+
platform.trainResult = t_mat
78+
platform.testResult(categories.value)
79+
80+
cb && cb(fite.epoch)
81+
} else {
82+
const fite = await model.fit(platform.trainInput, +iteration.value, rate.value, batch.value, rho.value)
83+
platform.plotLoss(fite.loss)
84+
platform.trainResult = await model.reduce(platform.trainInput)
85+
cb && cb(fite.epoch)
86+
}
87+
}
11788

11889
let rdim = null
11990
if (mode !== 'DR') {
12091
rdim = controller.input.number({ label: ' Size ', min: 1, max: 100, value: 10 })
12192
}
12293
const builder = new NeuralNetworkBuilder()
123-
builder.makeHtml(elm, { optimizer: true })
94+
builder.makeHtml(platform.setting.ml.configElement, { optimizer: true })
12495
const slbConf = controller.stepLoopButtons().init(done => {
12596
platform.init()
12697
if (platform.datas.length === 0) {
@@ -131,48 +102,13 @@ var dispAE = function (elm, platform) {
131102

132103
model.initialize(platform.datas.dimension, rd, builder.layers, builder.invlayers, builder.optimizer).then(done)
133104
})
134-
elm.append('span').text(' Iteration ')
135-
elm.append('select')
136-
.attr('name', 'iteration')
137-
.selectAll('option')
138-
.data([1, 10, 100, 1000, 10000])
139-
.enter()
140-
.append('option')
141-
.property('value', d => d)
142-
.text(d => d)
143-
elm.append('span').text(' Learning rate ')
144-
elm.append('input')
145-
.attr('type', 'number')
146-
.attr('name', 'rate')
147-
.attr('min', 0)
148-
.attr('max', 100)
149-
.attr('step', 0.01)
150-
.attr('value', 0.001)
151-
elm.append('span').text(' Batch size ')
152-
elm.append('input')
153-
.attr('type', 'number')
154-
.attr('name', 'batch')
155-
.attr('value', 10)
156-
.attr('min', 1)
157-
.attr('max', 100)
158-
.attr('step', 1)
159-
elm.append('span').text(' Sparse rho ')
160-
elm.append('input')
161-
.attr('type', 'number')
162-
.attr('name', 'rho')
163-
.attr('value', 0.02)
164-
.attr('min', 0)
165-
.attr('max', 1)
166-
.attr('step', 0.01)
105+
const iteration = controller.select({ label: ' Iteration ', values: [1, 10, 100, 1000, 10000] })
106+
const rate = controller.input.number({ label: ' Learning rate ', min: 0, max: 100, step: 0.01, value: 0.001 })
107+
const batch = controller.input.number({ label: ' Batch size ', min: 1, max: 100, value: 10 })
108+
const rho = controller.input.number({ label: ' Sparse rho ', min: 0, max: 1, step: 0.01, value: 0.02 })
109+
let threshold = null
167110
if (mode === 'AD') {
168-
elm.append('span').text(' threshold = ')
169-
elm.append('input')
170-
.attr('type', 'number')
171-
.attr('name', 'threshold')
172-
.attr('value', 0.02)
173-
.attr('min', 0)
174-
.attr('max', 10)
175-
.attr('step', 0.01)
111+
threshold = controller.input.number({ label: ' threshold = ', min: 0, max: 10, step: 0.01, value: 0.02 })
176112
}
177113
slbConf
178114
.step(cb => {
@@ -187,9 +123,3 @@ var dispAE = function (elm, platform) {
187123
model.terminate()
188124
}
189125
}
190-
191-
export default function (platform) {
192-
platform.setting.ml.usage =
193-
'Click and add data point. Next, click "Initialize". Finally, click "Fit" button repeatedly.'
194-
platform.setting.terminate = dispAE(platform.setting.ml.configElement, platform)
195-
}

‎js/view/birch.js

+12-36
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,25 @@
11
import BIRCH from '../../lib/model/birch.js'
2+
import Controller from '../controller.js'
23

3-
var dispBIRCH = function (elm, platform) {
4+
export default function (platform) {
5+
platform.setting.ml.usage = 'Click and add data point. Then, click "Fit" button.'
46
platform.setting.ml.reference = {
57
title: 'BIRCH (Wikipedia)',
68
url: 'https://en.wikipedia.org/wiki/BIRCH',
79
}
10+
const controller = new Controller(platform)
811
const fitModel = () => {
9-
const b = +elm.select('[name=b]').property('value')
10-
const t = +elm.select('[name=t]').property('value')
11-
const l = +elm.select('[name=l]').property('value')
12-
const model = new BIRCH(null, b, t, l)
12+
const model = new BIRCH(null, b.value, t.value, l.value)
1313
model.fit(platform.trainInput)
1414
const pred = model.predict(platform.trainInput)
1515
platform.trainResult = pred.map(v => v + 1)
16-
elm.select('[name=clusters]').text(new Set(pred).size)
16+
clusters.value = new Set(pred).size
1717
}
1818

19-
elm.append('span').text(' b ')
20-
elm.append('input').attr('type', 'number').attr('name', 'b').attr('min', 2).attr('max', 1000).attr('value', 10)
21-
elm.append('span').text(' t ')
22-
elm.append('input')
23-
.attr('type', 'number')
24-
.attr('name', 't')
25-
.attr('min', 0.01)
26-
.attr('max', 10)
27-
.attr('step', 0.01)
28-
.attr('value', 0.2)
29-
elm.append('span').text(' l ')
30-
elm.append('input').attr('type', 'number').attr('name', 'l').attr('min', 2).attr('max', 10000).attr('value', 10000)
31-
elm.append('span').text(' sub algorithm ')
32-
elm.append('select')
33-
.attr('name', 'subalgo')
34-
.selectAll('option')
35-
.data(['none'])
36-
.enter()
37-
.append('option')
38-
.attr('value', d => d)
39-
.text(d => d)
40-
const stepButton = elm.append('input').attr('type', 'button').attr('value', 'Fit').on('click', fitModel)
41-
elm.append('span').text(' Clusters: ')
42-
elm.append('span').attr('name', 'clusters')
43-
return () => {}
44-
}
45-
46-
export default function (platform) {
47-
platform.setting.ml.usage = 'Click and add data point. Then, click "Fit" button.'
48-
platform.setting.terminate = dispBIRCH(platform.setting.ml.configElement, platform)
19+
const b = controller.input.number({ label: ' b ', min: 2, max: 1000, value: 10 })
20+
const t = controller.input.number({ label: ' t ', min: 0.01, max: 10, step: 0.01, value: 0.2 })
21+
const l = controller.input.number({ label: ' l ', min: 2, max: 10000, value: 10000 })
22+
const subalgo = controller.select(['none'])
23+
controller.input.button('Fit').on('click', fitModel)
24+
const clusters = controller.text({ label: ' Clusters: ' })
4925
}

‎js/view/dbscan.js

+1-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ export default function (platform) {
6666
const minpts = controller.input.number({ label: 'min pts', min: 2, max: 1000, value: 5 }).on('change', fitModel)
6767
controller.input.button('Fit').on('click', fitModel)
6868
const clusters = controller.text({ label: ' Clusters: ' })
69-
platform.setting.terminate = () => {
69+
return () => {
7070
range.remove()
7171
}
7272
}

‎js/view/decision_tree.js

+1-1
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ export default function (platform) {
149149
const depth = controller.text('0')
150150
controller.text(' depth ')
151151

152-
platform.setting.terminate = () => {
152+
return () => {
153153
plotter.remove()
154154
}
155155
}

0 commit comments

Comments
 (0)
Please sign in to comment.