Skip to content

Commit 68f50ea

Browse files
committed
Refactor model initialization to accept configuration objects for affinity parameters
1 parent 3cc16aa commit 68f50ea

12 files changed

+94
-95
lines changed

Diff for: js/view/label_propagation.js

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ export default function (platform) {
77
let model = null
88
const fitModel = () => {
99
if (!model) {
10-
model = new LabelPropagation(method.value, sigma.value, k.value)
10+
model = new LabelPropagation({ name: method.value, sigma: sigma.value, k: k.value })
1111
model.init(
1212
platform.trainInput,
1313
platform.trainOutput.map(v => v[0])

Diff for: js/view/label_spreading.js

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ export default function (platform) {
77
let model = null
88
const fitModel = () => {
99
if (!model) {
10-
model = new LabelSpreading(alpha.value, method.value, sigma.value, k.value)
10+
model = new LabelSpreading(alpha.value, { name: method.value, sigma: sigma.value, k: k.value })
1111
model.init(
1212
platform.trainInput,
1313
platform.trainOutput.map(v => v[0])

Diff for: js/view/laplacian_eigenmaps.js

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ export default function (platform) {
1616
const k = controller.input.number({ label: 'k =', name: 'k_nearest', min: 1, max: 100, value: 10 })
1717
controller.input.button('Fit').on('click', () => {
1818
const dim = platform.dimension
19-
const model = new LaplacianEigenmaps(dim, method.value, k.value, sigma.value)
19+
const model = new LaplacianEigenmaps(dim, { name: method.value, k: k.value, sigma: sigma.value })
2020
const pred = model.predict(platform.trainInput)
2121
platform.trainResult = pred
2222
})

Diff for: js/view/spectral.js

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ export default function (platform) {
2929
knnSpan.element.style.display = 'none'
3030

3131
const slbConf = controller.stepLoopButtons().init(() => {
32-
const param = { sigma: sigma.value, k: k.value }
33-
model = new SpectralClustering(method.value, param)
32+
const param = { name: method.value, sigma: sigma.value, k: k.value }
33+
model = new SpectralClustering(param)
3434
model.init(platform.trainInput)
3535
clusters.value = model.size
3636
runSpan.element.querySelectorAll('input').forEach(elm => (elm.disabled = null))

Diff for: lib/model/label_propagation.js

+15-13
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@ export default class LabelPropagation {
99
// http://yamaguchiyuto.hatenablog.com/entry/2016/09/22/014202
1010
// https://github.com/scikit-learn/scikit-learn/blob/15a949460/sklearn/semi_supervised/_label_propagation.py
1111
/**
12-
* @param {'rbf' | 'knn'} [method] Method name
13-
* @param {number} [sigma] Sigma of normal distribution
14-
* @param {number} [k] Number of neighborhoods
12+
* @param {'rbf' | 'knn' | { name: 'rbf', sigma?: number, k?: number } | { name: 'knn', k?: number }} [method] Method name
1513
*/
16-
constructor(method = 'rbf', sigma = 0.1, k = Infinity) {
17-
this._k = k
18-
this._sigma = sigma
19-
this._affinity = method
14+
constructor(method = 'rbf') {
15+
if (typeof method === 'string') {
16+
this._affinity = { name: method }
17+
} else {
18+
this._affinity = method
19+
}
2020
}
2121

2222
_affinity_matrix(x) {
@@ -31,23 +31,25 @@ export default class LabelPropagation {
3131
}
3232

3333
const con = Matrix.zeros(n, n)
34-
if (this._k >= n) {
34+
const k = this._affinity.k ?? Infinity
35+
if (k >= n) {
3536
con.fill(1)
36-
} else if (this._k > 0) {
37+
} else if (k > 0) {
3738
for (let i = 0; i < n; i++) {
3839
const di = distances.row(i).value.map((v, i) => [v, i])
3940
di.sort((a, b) => a[0] - b[0])
40-
for (let j = 1; j < Math.min(this._k + 1, di.length); j++) {
41+
for (let j = 1; j < Math.min(k + 1, di.length); j++) {
4142
con.set(i, di[j][1], 1)
4243
}
4344
}
4445
con.add(con.t)
4546
con.div(2)
4647
}
4748

48-
if (this._affinity === 'rbf') {
49-
return Matrix.map(distances, (v, i) => (con.at(i) > 0 ? Math.exp(-(v ** 2) / this._sigma ** 2) : 0))
50-
} else if (this._affinity === 'knn') {
49+
if (this._affinity.name === 'rbf') {
50+
const sigma = this._affinity.sigma ?? 0.1
51+
return Matrix.map(distances, (v, i) => (con.at(i) > 0 ? Math.exp(-(v ** 2) / sigma ** 2) : 0))
52+
} else if (this._affinity.name === 'knn') {
5153
return Matrix.map(con, v => (v > 0 ? 1 : 0))
5254
}
5355
}

Diff for: lib/model/label_spreading.js

+15-13
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,14 @@ export default class LabelSpreading {
88
// https://github.com/scikit-learn/scikit-learn/blob/15a949460/sklearn/semi_supervised/_label_propagation.py
99
/**
1010
* @param {number} [alpha] Clamping factor
11-
* @param {'rbf' | 'knn'} [method] Method name
12-
* @param {number} [sigma] Sigma of normal distribution
13-
* @param {number} [k] Number of neighborhoods
11+
* @param {'rbf' | 'knn' | { name: 'rbf', sigma?: number, k?: number } | { name: 'knn', k?: number }} [method] Method name
1412
*/
15-
constructor(alpha = 0.2, method = 'rbf', sigma = 0.1, k = Infinity) {
16-
this._k = k
17-
this._sigma = sigma
18-
this._affinity = method
13+
constructor(alpha = 0.2, method = 'rbf') {
14+
if (typeof method === 'string') {
15+
this._affinity = { name: method }
16+
} else {
17+
this._affinity = method
18+
}
1919

2020
this._alpha = alpha
2121
}
@@ -32,23 +32,25 @@ export default class LabelSpreading {
3232
}
3333

3434
const con = Matrix.zeros(n, n)
35-
if (this._k >= n) {
35+
const k = this._affinity.k ?? Infinity
36+
if (k >= n) {
3637
con.fill(1)
37-
} else if (this._k > 0) {
38+
} else if (k > 0) {
3839
for (let i = 0; i < n; i++) {
3940
const di = distances.row(i).value.map((v, i) => [v, i])
4041
di.sort((a, b) => a[0] - b[0])
41-
for (let j = 1; j < Math.min(this._k + 1, di.length); j++) {
42+
for (let j = 1; j < Math.min(k + 1, di.length); j++) {
4243
con.set(i, di[j][1], 1)
4344
}
4445
}
4546
con.add(con.t)
4647
con.div(2)
4748
}
4849

49-
if (this._affinity === 'rbf') {
50-
return Matrix.map(distances, (v, i) => (con.at(i) > 0 ? Math.exp(-(v ** 2) / this._sigma ** 2) : 0))
51-
} else if (this._affinity === 'knn') {
50+
if (this._affinity.name === 'rbf') {
51+
const sigma = this._affinity.sigma ?? 0.1
52+
return Matrix.map(distances, (v, i) => (con.at(i) > 0 ? Math.exp(-(v ** 2) / sigma ** 2) : 0))
53+
} else if (this._affinity.name === 'knn') {
5254
return Matrix.map(con, v => (v > 0 ? 1 : 0))
5355
}
5456
}

Diff for: lib/model/laplacian_eigenmaps.js

+14-12
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,16 @@ export default class LaplacianEigenmaps {
99
// https://scikit-learn.org/stable/modules/generated/sklearn.manifold.SpectralEmbedding.html
1010
/**
1111
* @param {number} rd Reduced dimension
12-
* @param {'rbf' | 'knn'} [affinity] Affinity type name
13-
* @param {number} [k] Number of neighborhoods
14-
* @param {number} [sigma] Sigma of normal distribution
12+
* @param {'rbf' | 'knn' | { name: 'rbf', sigma?: number, k?: number } | { name: 'knn', k?: number }} [affinity] Affinity type name
1513
* @param {'unnormalized' | 'normalized'} [laplacian] Normalized laplacian matrix or not
1614
*/
17-
constructor(rd, affinity = 'rbf', k = 10, sigma = 1, laplacian = 'unnormalized') {
15+
constructor(rd, affinity = 'rbf', laplacian = 'unnormalized') {
1816
this._rd = rd
19-
this._affinity = affinity
20-
this._k = k
21-
this._sigma = sigma
17+
if (typeof affinity === 'string') {
18+
this._affinity = { name: affinity }
19+
} else {
20+
this._affinity = affinity
21+
}
2222
this._laplacian = laplacian
2323
}
2424

@@ -41,11 +41,12 @@ export default class LaplacianEigenmaps {
4141
}
4242

4343
const con = Matrix.zeros(n, n)
44-
if (this._k > 0) {
44+
const k = this._affinity.k ?? 10
45+
if (k > 0) {
4546
for (let i = 0; i < n; i++) {
4647
const di = distances.row(i).value.map((v, i) => [v, i])
4748
di.sort((a, b) => a[0] - b[0])
48-
for (let j = 1; j < Math.min(this._k + 1, di.length); j++) {
49+
for (let j = 1; j < Math.min(k + 1, di.length); j++) {
4950
con.set(i, di[j][1], 1)
5051
}
5152
}
@@ -54,9 +55,10 @@ export default class LaplacianEigenmaps {
5455
}
5556

5657
let W
57-
if (this._affinity === 'rbf') {
58-
W = Matrix.map(distances, (v, i) => (con.at(i) > 0 ? Math.exp(-(v ** 2) / this._sigma ** 2) : 0))
59-
} else if (this._affinity === 'knn') {
58+
if (this._affinity.name === 'rbf') {
59+
const sigma = this._affinity.sigma ?? 1
60+
W = Matrix.map(distances, (v, i) => (con.at(i) > 0 ? Math.exp(-(v ** 2) / sigma ** 2) : 0))
61+
} else if (this._affinity.name === 'knn') {
6062
W = Matrix.map(con, v => (v > 0 ? 1 : 0))
6163
}
6264
let d = W.sum(1).value

Diff for: lib/model/spectral.js

+3-8
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,13 @@ import LaplacianEigenmaps from './laplacian_eigenmaps.js'
77
export default class SpectralClustering {
88
// https://mr-r-i-c-e.hatenadiary.org/entry/20121214/1355499195
99
/**
10-
* @param {'rbf' | 'knn'} [affinity] Affinity type name
11-
* @param {object} [param] Config
12-
* @param {number} [param.sigma] Sigma of normal distribution
13-
* @param {number} [param.k] Number of neighborhoods
10+
* @param {'rbf' | 'knn' | { name: 'rbf', sigma?: number, k?: number } | { name: 'knn', k?: number }} [affinity] Affinity type name
1411
*/
15-
constructor(affinity = 'rbf', param = {}) {
12+
constructor(affinity = 'rbf') {
1613
this._size = 0
1714
this._epoch = 0
1815
this._clustering = new KMeanspp()
1916
this._affinity = affinity
20-
this._sigma = param.sigma || 1.0
21-
this._k = param.k || 10
2217
}
2318

2419
/**
@@ -44,7 +39,7 @@ export default class SpectralClustering {
4439
init(datas) {
4540
const n = datas.length
4641
this._n = n
47-
const le = new LaplacianEigenmaps(datas[0].length, this._affinity, this._k, this._sigma, 'normalized')
42+
const le = new LaplacianEigenmaps(datas[0].length, this._affinity, 'normalized')
4843
this.ready = false
4944
le.predict(datas)
5045
this._ev = le._ev

Diff for: tests/lib/model/label_propagation.test.js

+17-20
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,22 @@ import LabelPropagation from '../../../lib/model/label_propagation.js'
33

44
import { accuracy } from '../../../lib/evaluate/classification.js'
55

6-
test.each([{}, { method: 'rbf', sigma: 0.2 }, { method: 'knn', k: 10 }])(
7-
'semi-classifier %p',
8-
({ method, sigma, k }) => {
9-
const model = new LabelPropagation(method, sigma, k)
10-
const x = Matrix.concat(Matrix.randn(50, 2, 0, 0.2), Matrix.randn(50, 2, 5, 0.2)).toArray()
11-
const t = []
12-
const t_org = []
13-
for (let i = 0; i < x.length; i++) {
14-
t_org[i] = t[i] = String.fromCharCode('a'.charCodeAt(0) + Math.floor(i / 50))
15-
if (Math.random() < 0.5) {
16-
t[i] = null
17-
}
6+
test.each([undefined, 'rbf', { name: 'rbf', sigma: 0.2 }, { name: 'knn', k: 10 }])('semi-classifier %p', method => {
7+
const model = new LabelPropagation(method)
8+
const x = Matrix.concat(Matrix.randn(50, 2, 0, 0.2), Matrix.randn(50, 2, 5, 0.2)).toArray()
9+
const t = []
10+
const t_org = []
11+
for (let i = 0; i < x.length; i++) {
12+
t_org[i] = t[i] = String.fromCharCode('a'.charCodeAt(0) + Math.floor(i / 50))
13+
if (Math.random() < 0.5) {
14+
t[i] = null
1815
}
19-
model.init(x, t)
20-
for (let i = 0; i < 20; i++) {
21-
model.fit()
22-
}
23-
const y = model.predict(x)
24-
const acc = accuracy(y, t_org)
25-
expect(acc).toBeGreaterThan(0.95)
2616
}
27-
)
17+
model.init(x, t)
18+
for (let i = 0; i < 20; i++) {
19+
model.fit()
20+
}
21+
const y = model.predict(x)
22+
const acc = accuracy(y, t_org)
23+
expect(acc).toBeGreaterThan(0.95)
24+
})

Diff for: tests/lib/model/label_spreading.test.js

+21-20
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,26 @@ import LabelSpreading from '../../../lib/model/label_spreading.js'
33

44
import { accuracy } from '../../../lib/evaluate/classification.js'
55

6-
test.each([{}, { alpha: 0.5, method: 'rbf', sigma: 0.2 }, { alpha: 0.8, method: 'knn', k: 10 }])(
7-
'semi-classifier %s %p',
8-
({ alpha, method, sigma, k }) => {
9-
const model = new LabelSpreading(alpha, method, sigma, k)
10-
const x = Matrix.concat(Matrix.randn(50, 2, 0, 0.2), Matrix.randn(50, 2, 5, 0.2)).toArray()
11-
const t = []
12-
const t_org = []
13-
for (let i = 0; i < x.length; i++) {
14-
t_org[i] = t[i] = String.fromCharCode('a'.charCodeAt(0) + Math.floor(i / 50))
15-
if (Math.random() < 0.5) {
16-
t[i] = null
17-
}
6+
test.each([
7+
[undefined, undefined],
8+
[0.5, { name: 'rbf', sigma: 0.2 }],
9+
[0.8, { name: 'knn', k: 10 }],
10+
])('semi-classifier %p %p', (alpha, method) => {
11+
const model = new LabelSpreading(alpha, method)
12+
const x = Matrix.concat(Matrix.randn(50, 2, 0, 0.2), Matrix.randn(50, 2, 5, 0.2)).toArray()
13+
const t = []
14+
const t_org = []
15+
for (let i = 0; i < x.length; i++) {
16+
t_org[i] = t[i] = String.fromCharCode('a'.charCodeAt(0) + Math.floor(i / 50))
17+
if (Math.random() < 0.5) {
18+
t[i] = null
1819
}
19-
model.init(x, t)
20-
for (let i = 0; i < 20; i++) {
21-
model.fit()
22-
}
23-
const y = model.predict(x)
24-
const acc = accuracy(y, t_org)
25-
expect(acc).toBeGreaterThan(0.95)
2620
}
27-
)
21+
model.init(x, t)
22+
for (let i = 0; i < 20; i++) {
23+
model.fit()
24+
}
25+
const y = model.predict(x)
26+
const acc = accuracy(y, t_org)
27+
expect(acc).toBeGreaterThan(0.95)
28+
})

Diff for: tests/lib/model/laplacian_eigenmaps.test.js

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@ import LaplacianEigenmaps from '../../../lib/model/laplacian_eigenmaps.js'
66

77
import { coRankingMatrix } from '../../../lib/evaluate/dimensionality_reduction.js'
88

9-
describe.each([undefined, 'knn'])('dimensionality reduction affinity:%p', affinity => {
9+
describe.each([undefined, 'knn', { name: 'rbf' }])('dimensionality reduction affinity:%p', affinity => {
1010
test.each([undefined, 'normalized'])('laplacian: %p', laplacian => {
1111
const x = Matrix.concat(Matrix.randn(30, 5, 0, 0.2), Matrix.randn(30, 5, 5, 0.2)).toArray()
12-
const model = new LaplacianEigenmaps(2, affinity, undefined, undefined, laplacian)
12+
const model = new LaplacianEigenmaps(2, affinity, laplacian)
1313

1414
const y = model.predict(x)
1515
expect(y[0]).toHaveLength(2)

Diff for: tests/lib/model/spectral.test.js

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ import SpectralClustering from '../../../lib/model/spectral.js'
33

44
import { randIndex } from '../../../lib/evaluate/clustering.js'
55

6-
test('clustering', () => {
7-
const model = new SpectralClustering()
6+
test.each([undefined, 'rbf', { name: 'rbf', sigma: 0.5 }, { name: 'knn', k: 4 }])('clustering %p', affinity => {
7+
const model = new SpectralClustering(affinity)
88
const n = 5
99
const x = Matrix.concat(
1010
Matrix.concat(Matrix.randn(n, 2, 0, 0.1), Matrix.randn(n, 2, 5, 0.1)),

0 commit comments

Comments
 (0)