Commit e096a57

Add MARS (#897)
1 parent 2bea301 commit e096a57

6 files changed: 219 additions, 1 deletion

README.md (+1, -1)

@@ -124,7 +124,7 @@ for (let i = 0; i < n; i++) {
| clustering | (Soft / Kernel / Genetic / Weighted / Bisecting) k-means, k-means++, k-medois, k-medians, x-means, G-means, LBG, ISODATA, Fuzzy c-means, Possibilistic c-means, k-harmonic means, MacQueen, Hartigan-Wong, Elkan, Hamelry, Drake, Yinyang, Agglomerative (complete linkage, single linkage, group average, Ward's, centroid, weighted average, median), DIANA, Monothetic, Mutual kNN, Mean shift, DBSCAN, OPTICS, DTSCAN, HDBSCAN, DENCLUE, DBCLASD, BRIDGE, CLUES, PAM, CLARA, CLARANS, BIRCH, CURE, ROCK, C2P, PLSA, Latent dirichlet allocation, GMM, VBGMM, Affinity propagation, Spectral clustering, Mountain, (Growing) SOM, GTM, (Growing) Neural gas, Growing cell structures, LVQ, ART, SVC, CAST, CHAMELEON, COLL, CLIQUE, PROCLUS, ORCLUS, FINDIT, DOC, FastDOC, DiSH, NMF, Autoencoder |
| classification | (Fisher's) Linear discriminant, Quadratic discriminant, Mixture discriminant, Least squares, (Multiclass / Kernel) Ridge, (Complement / Negation / Universal-set / Selective) Naive Bayes (gaussian), AODE, (Fuzzy / Weighted) k-nearest neighbor, Radius neighbor, Nearest centroid, ENN, ENaN, NNBCA, ADAMENN, DANN, IKNN, Decision tree, Random forest, Extra trees, GBDT, XGBoost, ALMA, (Aggressive) ROMMA, (Bounded) Online gradient descent, (Budgeted online) Passive aggressive, RLS, (Selective-sampling) Second order perceptron, AROW, NAROW, Confidence weighted, CELLIP, IELLIP, Normal herd, Stoptron, (Kernelized) Pegasos, MIRA, Forgetron, Projectron, Projectron++, Banditron, Ballseptron, (Multiclass) BSGD, ILK, SILK, (Multinomial) Logistic regression, (Multinomial) Probit, SVM, Gaussian process, HMM, CRF, Bayesian Network, LVQ, (Average / Multiclass / Voted / Kernelized / Selective-sampling / Margin / Shifting / Budget / Tighter / Tightest) Perceptron, PAUM, RBP, ADALINE, MADALINE, MLP, ELM, LMNN |
| semi-supervised classification | k-nearest neighbor, Radius neighbor, Label propagation, Label spreading, k-means, GMM, S3VM, Ladder network |
-| regression | Least squares, Ridge, Lasso, Elastic net, RLS, Bayesian linear, Poisson, Least absolute deviations, Huber, Tukey, Least trimmed squares, Least median squares, Lp norm linear, SMA, Deming, Segmented, LOWESS, LOESS, spline, Naive Bayes, Gaussian process, Principal components, Partial least squares, Projection pursuit, Quantile regression, k-nearest neighbor, Radius neighbor, IDW, Nadaraya Watson, Priestley Chao, Gasser Muller, RBF Network, RVM, Decision tree, Random forest, Extra trees, GBDT, XGBoost, SVR, MLP, ELM, GMR, Isotonic, Ramer Douglas Peucker, Theil-Sen, Passing-Bablok, Repeated median |
+| regression | Least squares, Ridge, Lasso, Elastic net, RLS, Bayesian linear, Poisson, Least absolute deviations, Huber, Tukey, Least trimmed squares, Least median squares, Lp norm linear, SMA, Deming, Segmented, LOWESS, LOESS, spline, Naive Bayes, Gaussian process, Principal components, Partial least squares, Projection pursuit, Quantile regression, k-nearest neighbor, Radius neighbor, IDW, Nadaraya Watson, Priestley Chao, Gasser Muller, RBF Network, RVM, Decision tree, Random forest, Extra trees, GBDT, XGBoost, SVR, MARS, MLP, ELM, GMR, Isotonic, Ramer Douglas Peucker, Theil-Sen, Passing-Bablok, Repeated median |
| interpolation | Nearest neighbor, IDW, (Spherical) Linear, Brahmagupta, Logarithmic, Cosine, (Inverse) Smoothstep, Cubic, (Centripetal) Catmull-Rom, Hermit, Polynomial, Lagrange, Trigonometric, Spline, RBF Network, Akima, Natural neighbor, Delaunay |
| learning to rank | Ordered logistic, Ordered probit, PRank, OAP-BPM, RankNet |
| anomaly detection | Percentile, MAD, Tukey's fences, Grubbs's test, Thompson test, Tietjen Moore test, Generalized ESD, Hotelling, MT, MCD, k-nearest neighbor, LOF, COF, ODIN, LDOF, INFLO, LOCI, LoOP, RDF, LDF, KDEOS, RDOS, NOF, RKOF, ABOD, PCA, OCSVM, KDE, GMM, Isolation forest, Autoencoder, GAN |

js/model_selector.js (+1)

@@ -319,6 +319,7 @@ const AIMethods = [
	{ value: 'rbf', title: 'RBF Network' },
	{ value: 'rvm', title: 'RVM' },
	{ value: 'svr', title: 'Support vector regression' },
+	{ value: 'mars', title: 'MARS' },
	{ value: 'mlp', title: 'Multi-layer perceptron' },
	{ value: 'elm', title: 'Extreme learning machine' },
	{ value: 'neuralnetwork', title: 'Neuralnetwork' },

js/view/mars.js (+23)

@@ -0,0 +1,23 @@
import MARS from '../../lib/model/mars.js'
import Controller from '../controller.js'

export default function (platform) {
	platform.setting.ml.usage = 'Click and add data point. Next, click "Fit" button.'
	platform.setting.ml.reference = {
		author: 'J. H. Friedman',
		title: 'MULTIVARIATE ADAPTIVE REGRESSION SPLINES',
		year: 1990,
	}
	const controller = new Controller(platform)
	const fitModel = () => {
		const model = new MARS(mmax.value)
		model.fit(platform.trainInput, platform.trainOutput)

		const pred = model.predict(platform.testInput(2))
		platform.testResult(pred)
	}

	const mmax = controller.input.number({ label: 'M max', max: 100, min: 1, value: 5 })

	controller.input.button('Fit').on('click', fitModel)
}

lib/model/mars.js (+139)

@@ -0,0 +1,139 @@
import Matrix from '../util/matrix.js'

class Term {
	constructor(s = [], t = [], v = []) {
		this._s = s
		this._t = t
		this._v = v
	}

	prod(s, t, v) {
		return new Term(this._s.concat(s), this._t.concat(t), this._v.concat(v))
	}

	calc(x) {
		let val = 1
		for (let i = 0; i < this._s.length; i++) {
			val *= Math.max(0, this._s[i] * (x[this._v[i]] - this._t[i]))
		}
		return val
	}
}

/**
 * Multivariate Adaptive Regression Splines
 */
export default class MultivariateAdaptiveRegressionSplines {
	// Multivariate Adaptive Regression Splines
	// https://www.slac.stanford.edu/pubs/slacpubs/4750/slac-pub-4960.pdf
	// https://en.wikipedia.org/wiki/Multivariate_adaptive_regression_spline
	/**
	 * @param {number} mmax Maximum number of terms
	 */
	constructor(mmax) {
		this._mmax = mmax
		this._b = [new Term()]
		this._a = null
	}

	/**
	 * Fit model.
	 * @param {Array<Array<number>>} x Training data
	 * @param {Array<Array<number>>} y Target values
	 */
	fit(x, y) {
		const n = x.length
		const d = x[0].length
		y = Matrix.fromArray(y)

		let z = Matrix.ones(n, 1)
		let best_lof = Infinity
		let best_w = null
		while (this._b.length <= this._mmax) {
			let best_term = null
			let best_z = null
			for (let m = 0; m < this._b.length; m++) {
				for (let v = 0; v < d; v++) {
					for (let i = 0; i < n; i++) {
						if (this._b[m].calc(x[i]) === 0) continue
						const t = x[i][v]
						const termp = this._b[m].prod(1, t, v)
						const termm = this._b[m].prod(-1, t, v)
						const z1 = Matrix.resize(z, n, z.cols + 2)

						for (let j = 0; j < n; j++) {
							z1.set(j, z1.cols - 2, termp.calc(x[j]))
							z1.set(j, z1.cols - 1, termm.calc(x[j]))
						}

						const w = z1.tDot(z1).solve(z1.tDot(y))
						const yt = z1.dot(w)
						yt.sub(y)
						const e = yt.norm()
						if (e < best_lof) {
							best_term = { m, v, t }
							best_z = z1
							best_w = w
							best_lof = e
						}
					}
				}
			}

			this._b.push(
				this._b[best_term.m].prod(1, best_term.t, best_term.v),
				this._b[best_term.m].prod(-1, best_term.t, best_term.v)
			)
			z = best_z
			this._a = best_w
		}

		let best_w_b = this._b
		let best_k = z
		let best_k_b = this._b
		for (let i = this._b.length - 1; i >= 1; i--) {
			let b = Infinity
			const l = best_k
			const l_b = best_k_b
			for (let m = 1; m <= i; m++) {
				const z1 = l.copy()
				z1.remove(m, 1)
				const w = z1.tDot(z1).solve(z1.tDot(y))
				const yt = z1.dot(w)
				yt.sub(y)
				const e = yt.norm()

				if (e < b) {
					b = e
					best_k = z1
					best_k_b = l_b.concat()
					best_k_b.splice(m, 1)
				}
				if (e < best_lof) {
					best_lof = e
					best_w = w
					best_w_b = l_b.concat()
					best_w_b.splice(m, 1)
				}
			}
		}
		this._a = best_w
		this._b = best_w_b
	}

	/**
	 * Returns predicted values.
	 * @param {Array<Array<number>>} x Sample data
	 * @returns {Array<Array<number>>} Predicted values
	 */
	predict(x) {
		const n = x.length
		const z = Matrix.ones(n, this._b.length)
		for (let i = 0; i < n; i++) {
			for (let m = 0; m < this._b.length; m++) {
				z.set(i, m, this._b[m].calc(x[i]))
			}
		}
		return z.dot(this._a).toArray()
	}
}
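
For orientation, a sketch that is not part of the committed file: each Term above is a product of hinge functions of the form max(0, s * (x_v - t)), and predict() returns a least-squares-weighted sum of such terms (z.dot(this._a)). The arithmetic, with made-up signs and knots:

// Illustrative only: mirrors Term.calc for a two-factor term with
// s = [1, -1], t = [2, 3], v = [0, 1], i.e. max(0, x0 - 2) * max(0, 3 - x1).
const hinge = (s, x, t) => Math.max(0, s * (x - t))
const term = x => hinge(1, x[0], 2) * hinge(-1, x[1], 3)

console.log(term([4, 1])) // max(0, 2) * max(0, 2) = 4
console.log(term([1, 1])) // 0: the first hinge is zero below its knot at x0 = 2

// A fitted model then predicts y(x) ≈ a[0] * 1 + a[1] * term1(x) + a[2] * term2(x) + ...,
// with the weights a obtained by least squares in fit() above.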

tests/gui/view/mars.test.js (+38)

@@ -0,0 +1,38 @@
import { getPage } from '../helper/browser'

describe('regression', () => {
	/** @type {Awaited<ReturnType<getPage>>} */
	let page
	beforeEach(async () => {
		page = await getPage()
		const taskSelectBox = await page.waitForSelector('#ml_selector dl:first-child dd:nth-child(5) select')
		await taskSelectBox.selectOption('RG')
		const modelSelectBox = await page.waitForSelector('#ml_selector .model_selection #mlDisp')
		await modelSelectBox.selectOption('mars')
	})

	afterEach(async () => {
		await page?.close()
	})

	test('initialize', async () => {
		const methodMenu = await page.waitForSelector('#ml_selector #method_menu')
		const buttons = await methodMenu.waitForSelector('.buttons')

		const mmax = await buttons.waitForSelector('input:nth-of-type(1)')
		await expect(mmax.getAttribute('value')).resolves.toBe('5')
	})

	test('learn', async () => {
		const methodMenu = await page.waitForSelector('#ml_selector #method_menu')
		const buttons = await methodMenu.waitForSelector('.buttons')

		const methodFooter = await page.waitForSelector('#method_footer', { state: 'attached' })
		await expect(methodFooter.textContent()).resolves.toBe('')

		const fitButton = await buttons.waitForSelector('input[value=Fit]')
		await fitButton.evaluate(el => el.click())

		await expect(methodFooter.textContent()).resolves.toMatch(/^RMSE:[0-9.]+$/)
	})
})

tests/lib/model/mars.test.js (+17)

@@ -0,0 +1,17 @@
import Matrix from '../../../lib/util/matrix.js'
import MARS from '../../../lib/model/mars.js'

import { rmse } from '../../../lib/evaluate/regression.js'

test('fit', () => {
	const model = new MARS(20)
	const x = Matrix.randn(50, 2, 0, 5).toArray()
	const t = []
	for (let i = 0; i < x.length; i++) {
		t[i] = [x[i][0] + x[i][1] + (Math.random() - 0.5) / 2 + 5]
	}
	model.fit(x, t)
	const y = model.predict(x)
	const err = rmse(y, t)[0]
	expect(err).toBeLessThan(0.5)
})
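
A minimal standalone usage sketch of the new model, not part of the commit, assuming it sits alongside the test above so the relative import paths resolve the same way; the kinked toy target and the mmax value are illustrative:

import Matrix from '../../../lib/util/matrix.js'
import MARS from '../../../lib/model/mars.js'

// Toy 1-D data with a kink at x = 0; a single pair of hinge terms can represent it.
const x = Matrix.randn(100, 1, 0, 1).toArray()
const y = x.map(v => [2 * Math.max(0, v[0]) + 1])

const model = new MARS(10) // mmax: maximum number of terms (see the constructor JSDoc in lib/model/mars.js)
model.fit(x, y)
const pred = model.predict(x) // Array<Array<number>>, same shape as y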
